mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
21 Commits
v0.11.0rc3
...
v0.11.0
Author | SHA1 | Date | |
---|---|---|---|
b8b302cde4 | |||
f71952c1c4 | |||
d1007767c5 | |||
c75c2e70d6 | |||
9d9a2b77f1 | |||
6040e0b6c0 | |||
05bf0c52a1 | |||
c536881a7c | |||
ebce361c07 | |||
e4beabd2c8 | |||
febb688356 | |||
a1825fe645 | |||
bab9231bf1 | |||
c214d699fd | |||
c3dfb0f6dd | |||
83f3c9beae | |||
d0b178cef1 | |||
b3230e1ac0 | |||
03df0fb5d2 | |||
9471879bd4 | |||
ab5b6459df |
@ -48,7 +48,7 @@ steps:
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
|
@ -584,8 +584,9 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
|
@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
|
||||
set(SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
|
||||
list(APPEND SUPPORT_ARCHS 9.0a)
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
|
||||
list(APPEND SUPPORT_ARCHS 10.0a)
|
||||
endif()
|
||||
|
||||
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
if(FLASH_MLA_ARCHS)
|
||||
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
|
||||
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
|
||||
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc)
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_Extension_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_C
|
||||
DESTINATION vllm
|
||||
@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_extension_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_Extension_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_extension_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
else()
|
||||
# Create an empty target for setup.py when not targeting sm90a systems
|
||||
# Create empty targets for setup.py when not targeting sm90a systems
|
||||
add_custom_target(_flashmla_C)
|
||||
add_custom_target(_flashmla_extension_C)
|
||||
endif()
|
||||
|
||||
|
@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_page_table(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
params.mainloop,
|
||||
shared_storage.tensors,
|
||||
pipeline_page_table, pipeline_pt_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_cpasync(
|
||||
blk_coord,
|
||||
@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
params.mainloop_params,
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv,
|
||||
local_split_kv,
|
||||
/* must be shared pipe */
|
||||
pipeline_page_table, pipeline_pt_consumer_state
|
||||
);
|
||||
@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_tma</* paged= */ true>(
|
||||
blk_coord,
|
||||
@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||
}
|
||||
@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_tma<false>(
|
||||
blk_coord,
|
||||
@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||
}
|
||||
@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
mma(blk_coord,
|
||||
problem_shape,
|
||||
@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_mma_s, pipeline_mma_s_producer_state,
|
||||
pipeline_p_mma, pipeline_p_mma_consumer_state,
|
||||
pipeline_mma_o, pipeline_mma_o_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto split_kv = params.split_kv;
|
||||
auto local_split_kv = split_kv;
|
||||
auto split_kv = params.split_kv;
|
||||
auto local_split_kv = split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
compute(
|
||||
blk_coord,
|
||||
@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_mma_s, pipeline_mma_s_consumer_state,
|
||||
pipeline_p_mma, pipeline_p_mma_producer_state,
|
||||
pipeline_mma_o, pipeline_mma_o_consumer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
|
||||
@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
cutlass::arch::NamedBarrier(
|
||||
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
|
||||
kNamedBarrierEpilogue
|
||||
).arrive();
|
||||
).arrive_and_wait();
|
||||
|
||||
return;
|
||||
}
|
||||
|
@ -56,3 +56,11 @@ void cp_gather_cache(
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
|
||||
// Indexer K quantization and cache function
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt);
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat> // FLT_MIN
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
@ -396,6 +397,180 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_ds_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int64_t dst_idx_start =
|
||||
block_idx * block_stride + block_offset * entry_stride;
|
||||
|
||||
// Create 4 tile scales in shared memory
|
||||
__shared__ float smem[20];
|
||||
float* shard_abs_max = smem;
|
||||
float* tile_scales = smem + 16;
|
||||
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
|
||||
// Cast kv_cache to 16_bit for RoPE values
|
||||
scalar_t* kv_cache_16bit =
|
||||
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
|
||||
|
||||
// The last 64 threads handle the RoPE part
|
||||
if (threadIdx.x >= kv_lora_rank) {
|
||||
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
|
||||
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
|
||||
// RoPE values start after the packed 8-bit NoPE values and the
|
||||
// 32-bit scales
|
||||
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
|
||||
kv_cache_16bit[dst_idx] = k_pe[src_idx];
|
||||
return;
|
||||
}
|
||||
|
||||
// Determine the scale for each chunk of NoPE
|
||||
const int16_t tile_idx = threadIdx.x >> 7;
|
||||
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
|
||||
const int16_t lane_idx = threadIdx.x & 31;
|
||||
|
||||
// Load the NoPE element for this thread into registers
|
||||
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
|
||||
const scalar_t src_val = kv_c[src_idx];
|
||||
|
||||
// Warp-level reduction to find the max absolute value in the warp
|
||||
float max_abs = fabsf(src_val);
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
|
||||
#else
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
// The first lane of each warp in each tile writes the max_abs of this part
|
||||
// of the tile to shared memory
|
||||
if (lane_idx == 0) {
|
||||
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The first lane of the first warp in each tile computes the scale for the
|
||||
// tile and writes it to shared memory and to kv_cache
|
||||
if (warp_idx == 0 && lane_idx == 0) {
|
||||
float4 shard_abs_max_vec =
|
||||
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
|
||||
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
|
||||
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
|
||||
448.f;
|
||||
|
||||
// Avoid division by zero in `scaled_convert`
|
||||
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
|
||||
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
|
||||
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
|
||||
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now all threads in the block scale and write their element
|
||||
const float scale_val = tile_scales[tile_idx];
|
||||
const int64_t dst_idx = dst_idx_start + threadIdx.x;
|
||||
kv_cache[dst_idx] =
|
||||
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
|
||||
src_val, scale_val);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void indexer_k_quant_and_cache_kernel(
|
||||
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int head_dim, // dimension of each head
|
||||
const int quant_block_size, // quantization block size
|
||||
const int cache_block_size, // cache block size
|
||||
const int cache_stride, // stride for each token in kv_cache
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
) {
|
||||
constexpr int VEC_SIZE = 4;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
|
||||
threadIdx.y * blockDim.x + threadIdx.x) *
|
||||
VEC_SIZE;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
const int64_t block_idx = slot_idx / cache_block_size;
|
||||
const int64_t block_offset = slot_idx % cache_block_size;
|
||||
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
|
||||
return;
|
||||
}
|
||||
|
||||
float2 k_val = (reinterpret_cast<const float2*>(
|
||||
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
|
||||
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
|
||||
float amax = 0.0f;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
// Reduced amax
|
||||
for (int mask = 16; mask > 0; mask /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
|
||||
#else
|
||||
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
||||
#endif
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||
if (use_ue8m0) {
|
||||
scale = exp2f(ceilf(log2f(scale)));
|
||||
}
|
||||
|
||||
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
|
||||
block_offset * head_dim + head_dim_idx;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
kv_cache[dst_offset + i] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
const int64_t dst_scale_idx =
|
||||
block_idx * cache_block_size * cache_stride +
|
||||
cache_block_size * head_dim +
|
||||
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
|
||||
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -438,7 +613,7 @@ void reshape_and_cache(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
CALL_RESHAPE_AND_CACHE);
|
||||
}
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -509,6 +684,18 @@ void reshape_and_cache_flash(
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
@ -531,20 +718,44 @@ void concat_and_cache_mla(
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
|
||||
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
|
||||
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_c.itemsize() == 2,
|
||||
"kv_c.itemsize() must be 2 for fp8_ds_mla");
|
||||
TORCH_CHECK(k_pe.itemsize() == 2,
|
||||
"k_pe.itemsize() must be 2 for fp8_ds_mla");
|
||||
} else {
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
}
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
dim3 grid(num_tokens);
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
dim3 block(576);
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_DS_MLA);
|
||||
} else {
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -922,3 +1133,42 @@ void cp_gather_cache(
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
}
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(k.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
|
||||
cache_block_size, cache_stride, use_ue8m0);
|
||||
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt) {
|
||||
int num_tokens = k.size(0);
|
||||
int head_dim = k.size(1);
|
||||
int cache_block_size = kv_cache.size(1);
|
||||
int cache_stride = kv_cache.size(2);
|
||||
bool use_ue8m0 = scale_fmt == "ue8m0";
|
||||
|
||||
TORCH_CHECK(k.device() == kv_cache.device(),
|
||||
"k and kv_cache must be on the same device");
|
||||
TORCH_CHECK(k.device() == slot_mapping.device(),
|
||||
"k and slot_mapping must be on the same device");
|
||||
TORCH_CHECK(head_dim % quant_block_size == 0,
|
||||
"head_dim must be divisible by quant_block_size");
|
||||
|
||||
constexpr int vec_size = 4;
|
||||
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
|
||||
(quant_block_size * vec_size));
|
||||
dim3 block(32, vec_size);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
|
||||
CALL_INDEXER_K_QUANT_AND_CACHE);
|
||||
}
|
||||
|
@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_ds_mla") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
|
@ -713,6 +713,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
"int quant_block_size, str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
|
||||
&indexer_k_quant_and_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
|
@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12
|
||||
#
|
||||
# Example:
|
||||
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
|
||||
# Important: We build with an old version of Ubuntu to maintain broad
|
||||
# compatibility with other Linux OSes. The main reason for this is that the
|
||||
# glibc version is baked into the distro, and binaries built with one glibc
|
||||
# version are not backwards compatible with OSes that use an earlier version.
|
||||
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
@ -75,34 +80,19 @@ ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install Python and other dependencies
|
||||
# Install system dependencies and uv, then create Python virtual environment
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo python3-pip \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
|
||||
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
|
||||
&& ln -s /opt/venv/bin/python3 /usr/bin/python3 \
|
||||
&& ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
|
||||
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
# Activate virtual environment and add uv to PATH
|
||||
ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
@ -142,7 +132,7 @@ WORKDIR /workspace
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/cuda.txt \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
COPY . .
|
||||
@ -269,7 +259,7 @@ COPY requirements/lint.txt requirements/lint.txt
|
||||
COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
|
@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
|
||||
#
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
|
||||
ARG CUDA_VERSION=12.8.0
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG TARGETPLATFORM
|
||||
|
@ -8,6 +8,9 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
|
||||
|
||||
!!! tip
|
||||
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
||||
|
||||
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.
|
||||
|
||||
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
||||
|
||||
## Offline Inference
|
||||
|
@ -66,6 +66,9 @@ Restrict domains that vLLM can access for media URLs by setting
|
||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||
|
||||
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
|
||||
redirects from being followed to bypass domain restrictions.
|
||||
|
||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||
|
||||
While vLLM is designed to allow unsafe network services to be isolated to
|
||||
|
@ -54,6 +54,7 @@ def parse_args():
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
@ -118,9 +119,9 @@ def main(args):
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method.endswith("mtp"):
|
||||
elif args.method == "mtp":
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"method": "mtp",
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
}
|
||||
else:
|
||||
|
4
setup.py
4
setup.py
@ -322,6 +322,8 @@ class precompiled_wheel_utils:
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/_flashmla_extension_C.abi3.so",
|
||||
"vllm/_sparse_flashmla_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
@ -589,6 +591,8 @@ if _is_cuda():
|
||||
# not targeting a hopper system
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
|
@ -191,7 +191,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
),
|
||||
layer_names=[self.attn.layer_name],
|
||||
vllm_config=self.vllm_config,
|
||||
|
@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_concat_and_cache_ds_mla(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if dtype.itemsize != 2:
|
||||
pytest.skip("ds_mla only supports 16-bit input")
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
|
||||
|
||||
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device)
|
||||
|
||||
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
|
||||
tile_data = torch.zeros(128, dtype=dtype, device=device)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
|
||||
ref_cache_slice = ref_cache[block_idx, block_offset]
|
||||
ref_cache_16bit = ref_cache_slice.view(dtype)
|
||||
ref_cache_32bit = ref_cache_slice.view(torch.float32)
|
||||
|
||||
kv_c_data = kv_c[i]
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = (tile_idx + 1) * 128
|
||||
tile_data[:] = kv_c_data[tile_start:tile_end]
|
||||
|
||||
# tile_scale = tile_data.amax().to(torch.float32) / 448.
|
||||
# NOTE: Using torch's amax() gives different results,
|
||||
# so this must be manually computed.
|
||||
tile_data_float = tile_data.to(torch.float32)
|
||||
manual_max = abs(tile_data_float[0])
|
||||
for j in range(1, 128):
|
||||
manual_max = max(manual_max, abs(tile_data_float[j]))
|
||||
tile_scale = manual_max / 448.
|
||||
|
||||
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
|
||||
|
||||
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
|
||||
tile_data,
|
||||
tile_scale.item(),
|
||||
kv_dtype="fp8")
|
||||
|
||||
for j in range(qk_rope_head_dim):
|
||||
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla,
|
||||
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
kv_cache_slice = kv_cache[block_idx, block_offset]
|
||||
ref_cache_slice = ref_cache[block_idx, block_offset]
|
||||
|
||||
kv_nope = kv_cache_slice[:kv_lora_rank]
|
||||
ref_nope = ref_cache_slice[:kv_lora_rank]
|
||||
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
ref_scales = ref_cache_slice.view(
|
||||
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
|
||||
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
|
||||
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
|
279
tests/kernels/attention/test_deepgemm_attention.py
Normal file
279
tests/kernels/attention/test_deepgemm_attention.py
Normal file
@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits, get_num_sms,
|
||||
get_paged_mqa_logits_metadata)
|
||||
|
||||
|
||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (num_blocks, block_size, 1, head_dim)
|
||||
num_blocks, block_size, num_heads, head_dim = x.shape
|
||||
assert num_heads == 1
|
||||
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
x_fp8 = torch.empty(
|
||||
(num_blocks, block_size * (head_dim + 4)),
|
||||
device=x.device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
x_fp8[:, :block_size * head_dim] = x_scaled.view(
|
||||
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
|
||||
x_fp8[:,
|
||||
block_size * head_dim:] = sf.view(num_blocks,
|
||||
block_size).view(dtype=torch.uint8)
|
||||
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
||||
|
||||
|
||||
def per_custom_dims_cast_to_fp8(
|
||||
x: torch.Tensor, dims: tuple,
|
||||
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
||||
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
|
||||
|
||||
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
|
||||
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
|
||||
chunk_size = seq_len // 2
|
||||
cp_size = seq_len_kv // seq_len
|
||||
cp_id = cp_size // 3
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
for i in range(chunk_size):
|
||||
ke[i] = cp_id * chunk_size + i
|
||||
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
|
||||
return ks, ke
|
||||
|
||||
|
||||
def _ref_fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
):
|
||||
seq_len_kv = kv.shape[0]
|
||||
|
||||
k = kv
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
|
||||
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
>= cu_seqlen_ks[:, None])
|
||||
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
< cu_seqlen_ke[:, None])
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,and->hmn", q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="SM90 and SM100 only")
|
||||
def test_deepgemm_fp8_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
num_heads, head_dim = 32, 128
|
||||
for seq_len in (512, ):
|
||||
for seq_len_kv in (1024, ):
|
||||
for disable_cp in (False, True):
|
||||
q = torch.randn(
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv = torch.randn(seq_len_kv,
|
||||
head_dim,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16)
|
||||
weights = torch.randn(seq_len,
|
||||
num_heads,
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
if disable_cp:
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.arange(seq_len, dtype=torch.int,
|
||||
device="cuda") + (seq_len_kv - seq_len)
|
||||
else:
|
||||
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
|
||||
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
|
||||
|
||||
ref_logits = _ref_fp8_mqa_logits(
|
||||
q=q,
|
||||
kv=kv,
|
||||
weights=weights,
|
||||
cu_seqlen_ks=ks,
|
||||
cu_seqlen_ke=ke,
|
||||
)
|
||||
|
||||
ref_neginf_mask = ref_logits == float("-inf")
|
||||
neginf_mask = logits == float("-inf")
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
|
||||
def _ref_fp8_paged_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
batch_size, next_n, _, _ = q.size()
|
||||
_, block_size, _, _ = kv_cache.size()
|
||||
logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device=q.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
context_lens_list = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens_list[i]
|
||||
q_offsets = torch.arange(context_len - next_n,
|
||||
context_len,
|
||||
device="cuda")
|
||||
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
|
||||
0, 1).contiguous())
|
||||
for block_rk in range(cdiv(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
k_offsets = torch.arange(
|
||||
block_rk * block_size,
|
||||
(block_rk + 1) * block_size,
|
||||
device="cuda",
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
|
||||
<= q_offsets[:, None])
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n:(i + 1) * next_n,
|
||||
block_rk * block_size:(block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
|
||||
float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="SM90 and SM100 only")
|
||||
def test_deepgemm_fp8_paged_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
max_model_len = 4096
|
||||
for batch_size, next_n in [(4, 1), (2, 2)]:
|
||||
for heads, index_dim in [(32, 128)]:
|
||||
for avg_kv in (2048, ):
|
||||
num_blocks, blocksize = max_model_len * 2, 64
|
||||
|
||||
q = torch.randn(
|
||||
(batch_size, next_n, heads, index_dim),
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv_cache = torch.randn(
|
||||
(num_blocks, blocksize, 1, index_dim),
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
weights = torch.randn(
|
||||
(batch_size * next_n, heads),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
context_lens = (torch.randint(int(0.8 * avg_kv),
|
||||
int(1.2 * avg_kv),
|
||||
(batch_size, )).cuda().to(
|
||||
torch.int32))
|
||||
max_block_len = ((context_lens.max().item() + blocksize - 1) //
|
||||
blocksize * blocksize)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, max_block_len),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
counter = 0
|
||||
block_idx_pool = list(range(num_blocks))
|
||||
random.shuffle(block_idx_pool)
|
||||
for i in range(batch_size):
|
||||
ctx_len = int(context_lens[i].item())
|
||||
for j in range((ctx_len + blocksize - 1) // blocksize):
|
||||
block_tables[i][j] = block_idx_pool[counter]
|
||||
counter += 1
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
schedule_metadata = get_paged_mqa_logits_metadata(
|
||||
context_lens, blocksize, get_num_sms())
|
||||
logits = fp8_paged_mqa_logits(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_paged_mqa_logits(
|
||||
q,
|
||||
kv_cache,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
positions = (torch.arange(max_model_len,
|
||||
device="cuda").unsqueeze(0).expand(
|
||||
batch_size * next_n, -1))
|
||||
row_indices = (
|
||||
torch.arange(batch_size * next_n, device="cuda") // next_n)
|
||||
next_n_offset = (
|
||||
torch.arange(batch_size * next_n, device="cuda") % next_n)
|
||||
mask = positions <= (context_lens[row_indices] - next_n +
|
||||
next_n_offset).unsqueeze(1)
|
||||
|
||||
logits = logits.masked_fill(~mask, 0)
|
||||
ref_logits = ref_logits.masked_fill(~mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
descale_k = None
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k,
|
||||
)
|
||||
return flash_mla_with_kvcache(q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
|
119
tests/kernels/attention/test_flashmla_sparse.py
Normal file
119
tests/kernels/attention/test_flashmla_sparse.py
Normal file
@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _cuda_sm90_available() -> bool:
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
return major == 9
|
||||
|
||||
|
||||
def test_sparse_flashmla_metadata_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 128
|
||||
num_heads_k = 1
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
topk = 128
|
||||
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
assert tile_md.dtype == torch.int32
|
||||
assert num_splits.dtype == torch.int32
|
||||
|
||||
|
||||
def test_sparse_flashmla_decode_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 1
|
||||
head_dim_k = 576
|
||||
head_dim_v = 512
|
||||
num_heads_k = 1
|
||||
page_block_size = 64
|
||||
bytes_per_token = 656
|
||||
topk = 128
|
||||
|
||||
# Metadata
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
# q_heads_per_hk = num_heads_q // num_heads_k
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
|
||||
# Inputs
|
||||
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
|
||||
dtype=torch.bfloat16,
|
||||
device=device)
|
||||
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
indices = torch.zeros((batch_size, seqlen_q, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
block_table = torch.zeros((batch_size, 128),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
out, lse = fm.flash_mla_with_kvcache(q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_md,
|
||||
num_splits,
|
||||
indices=indices,
|
||||
is_fp8_kvcache=True)
|
||||
assert out.shape[0] == batch_size
|
||||
assert out.shape[-1] == head_dim_v
|
||||
assert lse.shape[0] == batch_size
|
||||
|
||||
|
||||
def test_sparse_flashmla_prefill_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
s_q = 1
|
||||
s_kv = 1
|
||||
h_q = 64 # kernel expects multiple of 64
|
||||
h_kv = 1
|
||||
d_qk = 576
|
||||
d_v = 512
|
||||
topk = 128
|
||||
|
||||
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
|
||||
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
|
||||
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
|
||||
|
||||
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
|
||||
d_v)
|
||||
assert out.shape == (s_q, h_q, d_v)
|
||||
assert max_logits.shape == (s_q, h_q)
|
||||
assert lse.shape == (s_q, h_q)
|
245
tests/kernels/attention/test_pack_unpack_triton.py
Normal file
245
tests/kernels/attention/test_pack_unpack_triton.py
Normal file
@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
|
||||
|
||||
def test_pack_seq_basic_fp8():
|
||||
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors (N, H, D)
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
|
||||
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check output shape and properties
|
||||
expected_shape = (B, max(lengths_list), H, D)
|
||||
assert packed.shape == expected_shape
|
||||
assert packed.dtype == dtype
|
||||
assert packed.device == x.device
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = sum(lengths_list[:b])
|
||||
seq_len = lengths_list[b]
|
||||
|
||||
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
|
||||
actual_data = packed[b, :seq_len].to(torch.float32)
|
||||
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_custom_padding_fp8():
|
||||
"""Test pack_seq_triton with custom padding values for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test with different padding values
|
||||
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
|
||||
result = pack_seq_triton(x, lengths, pad_value=pad_value)
|
||||
|
||||
# Check valid data
|
||||
for b in range(B):
|
||||
start_idx = b * 10
|
||||
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
|
||||
actual_data = result[b, :10].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
# Check padding (fp8 has limited range, so check for large values)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
if pad_value < 0:
|
||||
assert torch.all(padded_data < -50) # Large negative values
|
||||
elif pad_value > 0:
|
||||
assert torch.all(padded_data > 50) # Large positive values
|
||||
else:
|
||||
assert torch.allclose(padded_data,
|
||||
torch.zeros_like(padded_data),
|
||||
atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_default_negative_inf_padding_fp8():
|
||||
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
# B = 2
|
||||
N, H, D = 20, 8, 16
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check that padding is large negative values (fp8 representation of -inf)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
assert torch.all(
|
||||
padded_data < -100) # fp8 -inf is represented as large negative number
|
||||
|
||||
|
||||
def test_pack_seq_edge_cases_fp8():
|
||||
"""Test pack_seq_triton with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (1, 10, 8, 16)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 1, 4, 8)
|
||||
|
||||
# Test with different sequence lengths
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 7, 8, 16)
|
||||
|
||||
|
||||
def test_pack_seq_different_block_sizes_fp8():
|
||||
"""Test pack_seq_triton with different block sizes for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 100, 16, 32, 4
|
||||
lengths = torch.tensor([25, 25, 25, 25], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test different block sizes
|
||||
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
|
||||
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
|
||||
|
||||
assert result.shape == (B, 25, H, D)
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = b * 25
|
||||
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
|
||||
actual_data = result[b, :25].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_shape_consistency():
|
||||
"""Test that pack_seq_triton maintains shape consistency."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check shape consistency
|
||||
assert result.shape[0] == B # Batch dimension
|
||||
assert result.shape[1] == lengths.max().item() # Max sequence length
|
||||
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
|
||||
|
||||
|
||||
def test_pack_unpack_roundtrip_fp8():
|
||||
"""Test that pack -> unpack gives us back the original data for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]),
|
||||
(10, 4, 8, 3, [2, 4, 4]),
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]),
|
||||
(15, 8, 16, 3, [7, 5, 3]),
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Unpack the data
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
|
||||
# Check that we get back the original data (within fp8 precision)
|
||||
assert unpacked.shape == x.shape
|
||||
x_f32 = x.to(torch.float32)
|
||||
unpacked_f32 = unpacked.to(torch.float32)
|
||||
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Unpack without explicit start locations (computed in kernel)
|
||||
unpacked_with_loc = unpack_seq_triton(packed, lengths)
|
||||
assert_close(x_f32,
|
||||
unpacked_with_loc.to(torch.float32),
|
||||
rtol=1e-3,
|
||||
atol=1e-2)
|
||||
|
||||
|
||||
def test_unpack_seq_triton_edge_cases_fp8():
|
||||
"""Test unpack function with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
# Only compare the first 3 elements that were actually packed
|
||||
assert_close(x[:3].to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
@ -207,6 +207,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
|
||||
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
|
||||
min_transformers_version="4.54"),
|
||||
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
|
@ -8,7 +8,8 @@ import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
||||
# Avoid calling model.forward()
|
||||
def _initialize_kv_caches_v1(self, vllm_config):
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
scheduler_kv_cache_config = get_kv_cache_configs(
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config,
|
||||
kv_cache_specs,
|
||||
[10 * GiB_bytes],
|
||||
)[0]
|
||||
)
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
|
||||
kv_cache_configs)
|
||||
|
||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
||||
return 1, 0, scheduler_kv_cache_config
|
||||
|
@ -26,5 +26,5 @@ class DummyPlatform(Platform):
|
||||
|
||||
def get_attn_backend_cls(self, backend_name, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink):
|
||||
has_sink, use_sparse):
|
||||
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||
|
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -10,6 +11,7 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache(
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True) -> torch.Tensor:
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
|
||||
"""Create and prepopulate an MLA KV cache with context data.
|
||||
|
||||
Args:
|
||||
@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache(
|
||||
common_attn_metadata: Common attention metadata
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
or use sequential order
|
||||
kv_cache_dtype: Optional kv cache dtype string. When set to
|
||||
"fp8_ds_mla" the cache is populated using the
|
||||
fp8 DeepSeek MLA layout via concat_and_cache_mla.
|
||||
scale: Scaling factor forwarded to concat_and_cache_mla when the
|
||||
fp8 cache layout is requested.
|
||||
|
||||
Returns:
|
||||
MLA KV cache tensor
|
||||
@ -105,23 +114,61 @@ def create_and_prepopulate_kv_cache(
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
if not kv_c_contexts:
|
||||
raise ValueError("kv_c_contexts cannot be empty when using"
|
||||
" fp8_ds_mla cache dtype")
|
||||
kv_lora_rank = kv_c_contexts[0].shape[-1]
|
||||
rope_dim = k_pe_contexts[0].shape[-1]
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
scale_tensor = (scale
|
||||
if isinstance(scale, torch.Tensor) else torch.tensor(
|
||||
scale, dtype=torch.float32, device=device))
|
||||
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
|
||||
# Populate the cache with the context tokens
|
||||
# Start from block_id=1 since block_id=0 is considered the null block
|
||||
start_block_idx = 1
|
||||
for i in range(batch_size):
|
||||
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
|
||||
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
context_len = kv_c_context.shape[0]
|
||||
if context_len == 0:
|
||||
start_block_idx += cdiv(int(seq_lens[i]), block_size)
|
||||
continue
|
||||
|
||||
start = start_block_idx * block_size
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
slots = torch.arange(context_len, device=device,
|
||||
dtype=torch.long) + start
|
||||
ops.concat_and_cache_mla(
|
||||
kv_c_context,
|
||||
k_pe_context.squeeze(1),
|
||||
kv_cache,
|
||||
slots,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale_tensor,
|
||||
)
|
||||
else:
|
||||
kv_context = torch.cat(
|
||||
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
# Stay block aligned and allocate enough blocks for the new tokens
|
||||
start_block_idx += cdiv(int(seq_lens[i]), block_size)
|
||||
|
448
tests/v1/attention/test_sparse_mla_backends.py
Normal file
448
tests/v1/attention/test_sparse_mla_backends.py
Normal file
@ -0,0 +1,448 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for the FlashMLA sparse backend utilities."""
|
||||
|
||||
import math
|
||||
from types import MethodType, SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_mla_backends import (
|
||||
BATCH_SPECS, BatchSpec, MockAttentionLayer,
|
||||
create_and_prepopulate_kv_cache)
|
||||
from tests.v1.attention.utils import (create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
|
||||
FlashMLASparseImpl, FlashMLASparseMetadata)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
name: BATCH_SPECS[name]
|
||||
for name in [
|
||||
"mixed_small",
|
||||
"mixed_medium",
|
||||
"small_prefill",
|
||||
"medium_prefill",
|
||||
"single_prefill",
|
||||
]
|
||||
}
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
|
||||
query_lens=[256] * 2)
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
|
||||
seq_lens=[256] * 2, query_lens=[256] * 2)
|
||||
|
||||
|
||||
def _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
|
||||
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
|
||||
|
||||
# The first kv_lora_rank bytes store FP8 latent values with one scale per
|
||||
# 128 element tile written as float32 right after the latent payload.
|
||||
scales = cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
latent = torch.empty(kv_lora_rank,
|
||||
dtype=torch.float16,
|
||||
device=cache_slice.device)
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = tile_start + 128
|
||||
ops.convert_fp8(latent[tile_start:tile_end],
|
||||
cache_slice[tile_start:tile_end],
|
||||
float(scales[tile_idx].item()),
|
||||
kv_dtype="fp8")
|
||||
latent = latent.to(dtype)
|
||||
|
||||
rope_offset = kv_lora_rank // 2 + 8
|
||||
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
|
||||
return latent, rope_vals.clone()
|
||||
|
||||
|
||||
def _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
|
||||
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
|
||||
|
||||
if kv_c.numel() == 0:
|
||||
return kv_c.clone(), k_pe.clone()
|
||||
|
||||
kv_lora_rank = kv_c.shape[-1]
|
||||
rope_dim = k_pe.shape[-1]
|
||||
num_tokens = kv_c.shape[0]
|
||||
num_blocks = max(1, math.ceil(num_tokens / block_size))
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
|
||||
tmp_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=kv_c.device)
|
||||
slot_mapping = torch.arange(num_tokens,
|
||||
dtype=torch.long,
|
||||
device=kv_c.device)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c,
|
||||
k_pe,
|
||||
tmp_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale)
|
||||
|
||||
dequant_kv_c = torch.empty_like(kv_c)
|
||||
dequant_k_pe = torch.empty_like(k_pe)
|
||||
|
||||
for token_idx in range(num_tokens):
|
||||
slot = slot_mapping[token_idx].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
cache_slice = tmp_cache[block_idx, block_offset]
|
||||
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
|
||||
dequant_kv_c[token_idx] = latent
|
||||
dequant_k_pe[token_idx] = rope_vals
|
||||
|
||||
return dequant_kv_c, dequant_k_pe
|
||||
|
||||
|
||||
def test_sparse_backend_metadata_registration():
|
||||
backend = FlashMLASparseBackend
|
||||
|
||||
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
|
||||
assert backend.get_metadata_cls() is FlashMLASparseMetadata
|
||||
assert backend.get_impl_cls() is FlashMLASparseImpl
|
||||
|
||||
dtype_list = backend.get_supported_dtypes()
|
||||
assert torch.bfloat16 in dtype_list
|
||||
|
||||
shape = backend.get_kv_cache_shape(num_blocks=2,
|
||||
block_size=64,
|
||||
num_kv_heads=1,
|
||||
head_size=576)
|
||||
assert shape == (2, 64, 576)
|
||||
|
||||
|
||||
def test_sparse_decode_metadata_filters_prefill_indices():
|
||||
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
|
||||
metadata = FlashMLASparseDecodeAndContextMetadata(
|
||||
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
|
||||
num_splits=torch.tensor([1, 1], dtype=torch.int32),
|
||||
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
|
||||
prefill_context_lengths=prefill_context_lengths,
|
||||
)
|
||||
|
||||
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
|
||||
|
||||
context_indices, new_token_indices = metadata.filter_prefill_indices(
|
||||
indices)
|
||||
|
||||
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
|
||||
dtype=torch.int32)
|
||||
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
|
||||
dtype=torch.int32)
|
||||
|
||||
assert torch.equal(context_indices, expected_context)
|
||||
assert torch.equal(new_token_indices, expected_new_tokens)
|
||||
|
||||
|
||||
def test_sparse_impl_zero_fills_when_metadata_missing():
|
||||
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
|
||||
dummy_layer = object()
|
||||
q = torch.zeros((2, 1, 3))
|
||||
k_c = torch.zeros((2, 3))
|
||||
k_pe = torch.zeros((2, 1, 1))
|
||||
kv_cache = torch.zeros((1, 1, 1))
|
||||
output = torch.ones((2, 4))
|
||||
|
||||
result = FlashMLASparseImpl.forward(impl,
|
||||
dummy_layer,
|
||||
q,
|
||||
k_c,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata=None,
|
||||
output=output)
|
||||
|
||||
assert result is output
|
||||
assert torch.all(result == 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
kv_cache_dtype):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
|
||||
|
||||
# Model hyper-parameters (kept intentionally small for the unit test)
|
||||
num_heads = 128
|
||||
kv_lora_rank = 512
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
v_head_dim = 128
|
||||
head_size = kv_lora_rank + qk_rope_head_dim
|
||||
topk_tokens = 2048
|
||||
|
||||
max_seqlen = max(batch_spec.seq_lens)
|
||||
total_cache_tokens = sum(batch_spec.seq_lens)
|
||||
block_size = 64
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
max_model_len=max_seqlen,
|
||||
num_gpu_blocks=max(2048,
|
||||
cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size)
|
||||
model_config = vllm_config.model_config
|
||||
model_config.hf_config = SimpleNamespace(
|
||||
attn_module_list_cfg=[{
|
||||
"topk_tokens": topk_tokens
|
||||
}])
|
||||
model_config.hf_text_config = SimpleNamespace(
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
model_type="deepseek_v2",
|
||||
)
|
||||
model_config.dtype = dtype
|
||||
model_config.get_num_attention_heads = MethodType(
|
||||
lambda self, parallel_config: num_heads, model_config)
|
||||
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
|
||||
model_config)
|
||||
model_config.get_head_size = MethodType(lambda self: head_size,
|
||||
model_config)
|
||||
model_config.get_sliding_window = MethodType(lambda self: None,
|
||||
model_config)
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = 1.0 / math.sqrt(head_size)
|
||||
|
||||
# Shared MLA projection weights to keep reference and backend in sync
|
||||
W_UK = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UV = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# Build synthetic decode-only workload
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
reference_outputs = []
|
||||
|
||||
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(batch_spec.batch_size):
|
||||
s_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
ctx_len = s_len - q_len
|
||||
|
||||
q_c = torch.rand(q_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe_full = torch.rand(s_len,
|
||||
1,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
k_pe_full.squeeze(1),
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, ctx_len:] = causal_mask
|
||||
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
|
||||
|
||||
all_q_vllm.append(q_c)
|
||||
all_kv_c_vllm.append(kv_c_full[ctx_len:])
|
||||
all_k_pe_vllm.append(k_pe_full[ctx_len:])
|
||||
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
|
||||
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
|
||||
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
|
||||
sdpa_reference = torch.cat(reference_outputs, dim=0)
|
||||
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
vllm_config.cache_config.block_size,
|
||||
device,
|
||||
arange_block_indices=True)
|
||||
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
k_pe_contexts=k_pe_contexts,
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=False,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
builder_cls = FlashMLASparseBackend.get_builder_cls()
|
||||
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
|
||||
metadata = builder.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
|
||||
dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
|
||||
starts[:-1], seg_lengths)
|
||||
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
|
||||
prefix_lengths = seq_lengths - seg_lengths
|
||||
positions += np.repeat(prefix_lengths, seg_lengths)
|
||||
|
||||
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
|
||||
topk = metadata.topk_tokens
|
||||
debug_indices = torch.arange(topk, device=device,
|
||||
dtype=torch.int32).unsqueeze(0)
|
||||
token_positions = pos_gpu.unsqueeze(1)
|
||||
causal_mask = (debug_indices <= token_positions)
|
||||
debug_indices = torch.where(causal_mask, debug_indices,
|
||||
torch.full_like(debug_indices, -1))
|
||||
|
||||
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
|
||||
# buffer, so emulate that contract with a simple namespace mock.
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
|
||||
-1).clone()
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
|
||||
|
||||
ok, reason = flashmla.is_flashmla_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
|
||||
|
||||
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
|
||||
output_size=num_heads *
|
||||
(qk_nope_head_dim + v_head_dim),
|
||||
bias=False).to(device=device,
|
||||
dtype=dtype)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(metadata.num_actual_tokens,
|
||||
num_heads * v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
backend_output = impl.forward(layer,
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
metadata,
|
||||
output=out_buffer)
|
||||
|
||||
assert backend_output.shape == sdpa_reference.shape
|
||||
assert backend_output.dtype == sdpa_reference.dtype
|
||||
assert torch.isfinite(backend_output).all()
|
||||
|
||||
torch.testing.assert_close(backend_output,
|
||||
sdpa_reference,
|
||||
rtol=0.5,
|
||||
atol=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,max_buf,start,expected",
|
||||
[
|
||||
# Basic split: totals per chunk ≤ max_buf
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Non-zero start index
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would
|
||||
# overflow
|
||||
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||
# All requests fit in a single chunk
|
||||
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||
# Large buffer with non-zero start
|
||||
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||
],
|
||||
)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||
assert out == expected
|
@ -168,7 +168,6 @@ def create_standard_kv_cache_spec(
|
||||
vllm_config.parallel_config),
|
||||
head_size=vllm_config.model_config.get_head_size(),
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
use_mla=vllm_config.model_config.use_mla,
|
||||
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,8 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec,
|
||||
KVCacheTensor, MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=None):
|
||||
return FullAttentionSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=1):
|
||||
return SlidingWindowSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@ -894,7 +891,6 @@ def test_merge_kv_cache_spec():
|
||||
num_kv_heads=full_spec.num_kv_heads,
|
||||
head_size=full_spec.head_size,
|
||||
dtype=full_spec.dtype,
|
||||
use_mla=full_spec.use_mla,
|
||||
sliding_window=1,
|
||||
),
|
||||
]
|
||||
@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len,
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
# Estimate the maximum model length, 16384 model_len need 8GB
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
sliding_window=1024,
|
||||
)
|
||||
|
||||
@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config():
|
||||
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def new_mla_spec(cache_dtype_str=None):
|
||||
return MLAAttentionSpec(block_size=16,
|
||||
num_kv_heads=16,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
cache_dtype_str=cache_dtype_str)
|
||||
|
||||
|
||||
def test_merge_mla_spec():
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(),
|
||||
new_mla_spec(),
|
||||
]
|
||||
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
|
||||
assert mla_spec == new_mla_spec()
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
]
|
||||
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
|
||||
assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_mla_spec(cache_dtype_str=None),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
kv_cache_specs = [
|
||||
new_kv_cache_spec(),
|
||||
new_mla_spec(),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_kv_cache_spec(),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
@ -98,7 +98,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
@ -107,7 +106,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
],
|
||||
@ -1338,7 +1336,6 @@ def test_eagle_with_sliding_window():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
use_mla=False,
|
||||
)
|
||||
manager = KVCacheManager(
|
||||
KVCacheConfig(
|
||||
|
@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -100,7 +99,6 @@ def test_sliding_window_possible_cached_prefix():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -165,7 +163,6 @@ def test_chunked_local_attention_remove_skipped_blocks():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
@ -217,7 +214,6 @@ def test_sliding_window_remove_skipped_blocks():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
@ -285,7 +281,6 @@ def test_get_num_blocks_to_allocate():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4, # Placeholder value, not related to test result
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -308,7 +303,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4, # Placeholder value, not related to test result
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
|
@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
@ -222,3 +224,66 @@ def test_eagle_correctness(
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
||||
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
|
||||
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
|
||||
],
|
||||
ids=["mimo", "deepseek"])
|
||||
def test_mtp_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, int],
|
||||
mm_enabled: bool,
|
||||
):
|
||||
# Generate test prompts inside the function instead of using fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using MTP speculative decoding.
|
||||
model_setup: (method, model_name, tp_size)
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
|
||||
method, model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
speculative_config={
|
||||
"method": method,
|
||||
"num_speculative_tokens": 1,
|
||||
"max_model_len": 2048,
|
||||
},
|
||||
max_model_len=2048,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 80% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
@ -836,8 +836,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
mock_spec = FullAttentionSpec(block_size=16,
|
||||
num_kv_heads=1,
|
||||
head_size=64,
|
||||
dtype=torch.float16,
|
||||
use_mla=False)
|
||||
dtype=torch.float16)
|
||||
|
||||
mock_executor.get_kv_cache_specs.return_value = [{
|
||||
"default": mock_spec
|
||||
|
@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
time.sleep(self._hand_shake_latency)
|
||||
# These should've been done in register_kv_caches(), called by
|
||||
# gpu_model_runner. Here we just hardcode some dummy values.
|
||||
self.slot_size_bytes = 4096
|
||||
self.block_len = self.slot_size_bytes * self.block_size
|
||||
slot_size_bytes = 4096
|
||||
self.slot_size_per_layer = [slot_size_bytes]
|
||||
self.block_len_per_layer = [slot_size_bytes * self.block_size]
|
||||
self.num_blocks = 1
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
|
||||
@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_len=self.block_len,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
@ -485,8 +486,8 @@ class TestNixlHandshake:
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
worker.slot_size_bytes = 4096
|
||||
worker.block_len = worker.slot_size_bytes * worker.block_size
|
||||
worker.slot_size_per_layer = [4096]
|
||||
worker.block_len_per_layer = [4096 * worker.block_size]
|
||||
worker.num_blocks = 1
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
|
||||
@ -498,7 +499,7 @@ class TestNixlHandshake:
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_len=worker.block_len,
|
||||
block_lens=worker.block_len_per_layer,
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout=mismatched_layout,
|
||||
)
|
||||
|
@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
"target_attn_1": mock.MagicMock(),
|
||||
"target_attn_2": mock.MagicMock()
|
||||
}
|
||||
target_indx_layers: dict[str, mock.MagicMock] = {}
|
||||
# Draft model has one extra attention layer compared to target model
|
||||
all_attn_layers = {
|
||||
**target_attn_layers, "draft_extra_attn": mock.MagicMock()
|
||||
}
|
||||
|
||||
all_indx_layers: dict[str, mock.MagicMock] = {}
|
||||
|
||||
# Make mock_get_layers return different values for each call
|
||||
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
||||
mock_get_layers.side_effect = [
|
||||
target_attn_layers, target_indx_layers, all_attn_layers,
|
||||
all_indx_layers
|
||||
]
|
||||
|
||||
# Setup mock for pp group to return the appropriate value for world size
|
||||
mock_pp_group = mock.MagicMock()
|
||||
@ -658,6 +664,9 @@ def test_propose_tree(spec_token_tree):
|
||||
# Mock runner for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builders = [
|
||||
attn_metadata_builder
|
||||
]
|
||||
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
|
||||
attn_metadata_builder
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
|
201
tests/v1/spec_decode/test_mtp.py
Normal file
201
tests/v1/spec_decode/test_mtp.py
Normal file
@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
get_attention_backend)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VllmConfig)
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
|
||||
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
|
||||
|
||||
|
||||
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
|
||||
"""Create an MTP proposer with unified model configuration."""
|
||||
model_config = ModelConfig(model=mimo_7b_dir,
|
||||
runner="generate",
|
||||
max_model_len=100,
|
||||
trust_remote_code=True)
|
||||
|
||||
speculative_config = SpeculativeConfig(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=ParallelConfig(),
|
||||
model=mimo_7b_dir,
|
||||
method="mtp",
|
||||
num_speculative_tokens=num_speculative_tokens,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
speculative_config=speculative_config,
|
||||
device_config=DeviceConfig(device=current_platform.device_type),
|
||||
parallel_config=ParallelConfig(),
|
||||
load_config=LoadConfig(),
|
||||
scheduler_config=SchedulerConfig())
|
||||
|
||||
return EagleProposer(vllm_config=vllm_config,
|
||||
device=current_platform.device_type)
|
||||
|
||||
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
||||
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
|
||||
mock_get_pp_group):
|
||||
"""Test MTP-specific model loading with unified model approach."""
|
||||
|
||||
# Setup mocks
|
||||
mock_model = mock.MagicMock()
|
||||
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
|
||||
mock_get_model.return_value = mock_model
|
||||
|
||||
target_attn_layers = {"target_attn_1": mock.MagicMock()}
|
||||
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
|
||||
target_indexer_layers: dict = {}
|
||||
all_indexer_layers: dict = {}
|
||||
|
||||
mock_get_layers.side_effect = [
|
||||
target_attn_layers, target_indexer_layers, all_attn_layers,
|
||||
all_indexer_layers
|
||||
]
|
||||
|
||||
mock_pp_group = mock.MagicMock()
|
||||
mock_pp_group.world_size = 1
|
||||
mock_get_pp_group.return_value = mock_pp_group
|
||||
|
||||
# Create target model
|
||||
class _TargetModelStub(LlamaForCausalLM):
|
||||
model: mock.MagicMock
|
||||
lm_head: mock.MagicMock
|
||||
|
||||
target_model = mock.create_autospec(_TargetModelStub, instance=True)
|
||||
target_model.model = mock.MagicMock()
|
||||
target_model.model.embed_tokens.weight.shape = (131072, 4096)
|
||||
target_model.lm_head = mock.MagicMock()
|
||||
|
||||
# Create MTP proposer
|
||||
proposer = _create_mtp_proposer(num_speculative_tokens=4)
|
||||
proposer.load_model(target_model)
|
||||
|
||||
# Verify MTP-specific behavior:
|
||||
# Model is loaded
|
||||
mock_get_model.assert_called_once()
|
||||
# MTP shares lm_head with target model
|
||||
assert proposer.model.lm_head == target_model.lm_head
|
||||
# MTP shares embed_tokens with target model
|
||||
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1])
|
||||
def test_mtp_propose(num_speculative_tokens, monkeypatch):
|
||||
"""Test that MTP's forward method returns hidden states directly"""
|
||||
|
||||
device = torch.device(current_platform.device_type)
|
||||
batch_size = 2
|
||||
seq_lens = [5, 3]
|
||||
total_tokens = sum(seq_lens)
|
||||
vocab_size = 100
|
||||
|
||||
proposer = _create_mtp_proposer(num_speculative_tokens)
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
# Mock the MTP model to verify it returns hidden states directly
|
||||
model_mock = mock.MagicMock()
|
||||
|
||||
# MTP returns hidden states directly
|
||||
if num_speculative_tokens == 1:
|
||||
model_mock.return_value = torch.zeros(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
else:
|
||||
# Multiple forward passes for multi-token speculation
|
||||
forward_returns = []
|
||||
for i in range(num_speculative_tokens):
|
||||
if i == 0:
|
||||
h_states = torch.zeros(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
else:
|
||||
h_states = torch.zeros(batch_size, hidden_size, device=device)
|
||||
forward_returns.append(h_states)
|
||||
model_mock.side_effect = forward_returns
|
||||
|
||||
# Mock compute_logits
|
||||
def create_deterministic_logits(batch_size, vocab_size, token_offset):
|
||||
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
|
||||
logits[:, token_offset] = 100.0
|
||||
return logits
|
||||
|
||||
if num_speculative_tokens == 1:
|
||||
model_mock.compute_logits.return_value = create_deterministic_logits(
|
||||
batch_size, vocab_size, 42)
|
||||
else:
|
||||
logits_returns = [
|
||||
create_deterministic_logits(batch_size, vocab_size, 42 + i)
|
||||
for i in range(num_speculative_tokens)
|
||||
]
|
||||
model_mock.compute_logits.side_effect = logits_returns
|
||||
|
||||
proposer.model = model_mock
|
||||
proposer.attn_layer_names = ["layer.0"]
|
||||
|
||||
# Prepare inputs
|
||||
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
|
||||
common_attn_metadata = create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
|
||||
target_token_ids = torch.randint(0,
|
||||
vocab_size, (total_tokens, ),
|
||||
device=device)
|
||||
target_positions = torch.cat([
|
||||
torch.arange(seq_lens[0], device=device),
|
||||
torch.arange(seq_lens[1], device=device)
|
||||
])
|
||||
target_hidden_states = torch.randn(total_tokens,
|
||||
hidden_size,
|
||||
device=device)
|
||||
next_token_ids = torch.randint(0,
|
||||
vocab_size, (batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
# Setup attention metadata
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
|
||||
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
vllm_config=proposer.vllm_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.attn_metadata_builder = attn_metadata_builder
|
||||
|
||||
# Run propose
|
||||
result = proposer.propose(target_token_ids=target_token_ids,
|
||||
target_positions=target_positions,
|
||||
target_hidden_states=target_hidden_states,
|
||||
next_token_ids=next_token_ids,
|
||||
last_token_indices=None,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
sampling_metadata=sampling_metadata)
|
||||
|
||||
# Verify the model was called correctly
|
||||
assert model_mock.called
|
||||
# Verify output shape
|
||||
assert result.shape == (batch_size, num_speculative_tokens)
|
@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
|
||||
kv_cache_config = KVCacheConfig(
|
||||
|
@ -1678,6 +1678,15 @@ def cp_gather_cache(src_cache: torch.Tensor,
|
||||
cu_seq_lens, batch_size, seq_starts)
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
kv_cache_dtype: str) -> None:
|
||||
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
|
||||
quant_block_size,
|
||||
kv_cache_dtype)
|
||||
|
||||
|
||||
def get_device_attribute(attribute: int, device: int) -> int:
|
||||
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
|
||||
|
||||
|
@ -70,6 +70,7 @@ class AttentionBackend(ABC):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -95,6 +95,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sparse: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -155,6 +156,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.use_sparse = use_sparse
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -187,7 +189,8 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink)
|
||||
has_sink=self.has_sink,
|
||||
use_sparse=use_sparse)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
from typing import ClassVar, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
||||
subclass_attention_backend)
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
@ -28,6 +28,8 @@ def create_chunked_local_attention_backend(
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
|
@ -138,3 +138,208 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _pack_seq_kernel(
|
||||
x_ptr, # [N, D]
|
||||
out_ptr, # [B, Lmax, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
N: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
PAD_VALUE: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# Compute start index and sequence length from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
|
||||
# compute input row indices for valid (b, t)
|
||||
in_row = in_start + off_t
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# Pointers
|
||||
# x_ptr: row-major [N, D]
|
||||
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [B, Lmax, D]
|
||||
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
|
||||
None] * D + off_d[None, :]
|
||||
|
||||
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
|
||||
d_mask = off_d[None, :] < D
|
||||
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
|
||||
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
|
||||
|
||||
# Load & write only where within seq_len
|
||||
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def pack_seq_triton(x: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
pad_value: float = -float('inf'),
|
||||
block_t: int = 64,
|
||||
block_d: int = 64) -> torch.Tensor:
|
||||
"""
|
||||
Pack sequences of different lengths into a batched tensor.
|
||||
|
||||
Args:
|
||||
x: [N, ...] - input tensor where N is total number of tokens
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
pad_value: value to use for padding
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
packed: [B, Lmax, ...] - packed tensor
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (N, -1)
|
||||
original_shape = x.shape
|
||||
if len(original_shape) > 2:
|
||||
N = original_shape[0]
|
||||
x_reshaped = x.reshape(N, -1)
|
||||
D = x_reshaped.shape[1]
|
||||
else:
|
||||
N, D = x.shape
|
||||
x_reshaped = x
|
||||
|
||||
B = lengths.numel()
|
||||
Lmax = int(lengths.max().item())
|
||||
|
||||
# Starts are computed inside the kernel from lengths
|
||||
|
||||
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_pack_seq_kernel[grid](x_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
N,
|
||||
D,
|
||||
Lmax,
|
||||
PAD_VALUE=float(pad_value),
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 2:
|
||||
output_shape = (B, Lmax) + original_shape[1:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _unpack_seq_triton_kernel(
|
||||
packed_ptr, # [B, Lmax, D]
|
||||
out_ptr, # [N, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
B: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# bounds: compute start from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# compute output row indices for valid (b, t)
|
||||
out_row = in_start + off_t
|
||||
|
||||
# Pointers
|
||||
# packed_ptr: row-major [B, Lmax, D]
|
||||
packed_row_ptr = packed_ptr + (pid_b * Lmax +
|
||||
off_t)[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [N, D]
|
||||
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# Load from packed tensor and store to output
|
||||
d_mask = off_d[None, :] < D
|
||||
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def unpack_seq_triton(packed_tensor: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
block_t: int = 64,
|
||||
block_d: int = 64) -> torch.Tensor:
|
||||
"""
|
||||
Unpack a packed decode query tensor back to the original format.
|
||||
Efficient Triton implementation.
|
||||
|
||||
Args:
|
||||
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
unpacked_tensor: [N, ...] where N = sum(lengths)
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
|
||||
original_shape = packed_tensor.shape
|
||||
if len(original_shape) > 3:
|
||||
B, Lmax = original_shape[:2]
|
||||
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
|
||||
D = packed_reshaped.shape[2]
|
||||
else:
|
||||
B, Lmax, D = packed_tensor.shape
|
||||
packed_reshaped = packed_tensor
|
||||
|
||||
# Calculate total number of elements
|
||||
N = int(lengths.sum().item())
|
||||
|
||||
out = torch.empty((N, D),
|
||||
device=packed_tensor.device,
|
||||
dtype=packed_tensor.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_unpack_seq_triton_kernel[grid](packed_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
B,
|
||||
Lmax,
|
||||
D,
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 3:
|
||||
output_shape = (N, ) + original_shape[2:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
|
@ -19,6 +19,15 @@ if current_platform.is_cuda():
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
@ -37,24 +46,34 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_q_tokens_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
num_heads_q: Optional[int] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
- cache_seqlens: (batch_size), dtype torch.int32.
|
||||
- num_q_tokens_per_head_k:
|
||||
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
|
||||
- num_heads_k: The number of k heads.
|
||||
- num_heads_q:
|
||||
The number of q heads.
|
||||
This argument is optional when sparse attention is not enabled
|
||||
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
|
||||
- topk: If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array
|
||||
passed to `flash_mla_with_kvcache_sm90` will be attended to.
|
||||
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
Returns:
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
|
||||
num_heads_per_head_k,
|
||||
num_heads_k)
|
||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
|
||||
is_fp8_kvcache, topk)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
@ -69,45 +88,95 @@ def flash_mla_with_kvcache(
|
||||
causal: bool = False,
|
||||
descale_q: Optional[torch.Tensor] = None,
|
||||
descale_k: Optional[torch.Tensor] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
descale_q: (batch_size), torch.float32. Descaling factors for Q.
|
||||
descale_k: (batch_size), torch.float32. Descaling factors for K.
|
||||
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
- cache_seqlens: (batch_size), torch.int32.
|
||||
- head_dim_v: Head dimension of v.
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
|
||||
returned by get_mla_metadata.
|
||||
- num_splits:
|
||||
(batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
- softmax_scale: float.
|
||||
The scale of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
- causal: bool. Whether to apply causal attention mask.
|
||||
- descale_q: (batch_size),
|
||||
torch.float32. Descaling factors for Q, used for fp8 quantization.
|
||||
- descale_k: (batch_size),
|
||||
torch.float32. Descaling factors for K, used for fp8 quantization.
|
||||
- is_fp8_kvcache: bool.
|
||||
Whether the k_cache and v_cache are in fp8 format.
|
||||
For the format of FP8 KV cache, please refer to README.md
|
||||
- indices: (batch_size, seq_len_q, topk), torch.int32.
|
||||
If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array will be attended to.
|
||||
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||
For details about how to set up `indices`, please refer to README.md.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
Returns:
|
||||
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1]**(-0.5)
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
descale_q,
|
||||
descale_k,
|
||||
)
|
||||
if indices is not None:
|
||||
# NOTE (zyongye): sparse attention is also causal
|
||||
# since it only attend to the tokens before
|
||||
# but here `causal` should not be specified
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
assert (descale_q is None) == (
|
||||
descale_k is None
|
||||
), "descale_q and descale_k should be both None or both not None"
|
||||
|
||||
# Note(hc): need revisit when we support DCP with decode query_len > 1.
|
||||
return out.squeeze(1), softmax_lse.squeeze(-1)
|
||||
if indices is None and q.element_size() == 1:
|
||||
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
|
||||
else:
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
|
||||
indices)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
def flash_mla_sparse_prefill(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sparse attention prefill kernel
|
||||
|
||||
Args:
|
||||
- q: [s_q, h_q, d_qk], bfloat16
|
||||
- kv: [s_kv, h_kv, d_qk], bfloat16
|
||||
- indices: [s_q, h_kv, topk], int32.
|
||||
Invalid indices should be set to -1 or numbers >= s_kv
|
||||
- sm_scale: float
|
||||
- d_v: The dimension of value vectors. Can only be 512
|
||||
|
||||
Returns:
|
||||
- (output, max_logits, lse)
|
||||
About the definition of output,
|
||||
max_logits and lse, please refer to README.md
|
||||
- output: [s_q, h_q, d_v], bfloat16
|
||||
- max_logits: [s_q, h_q], float
|
||||
- lse: [s_q, h_q], float, 2-based log-sum-exp
|
||||
"""
|
||||
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
|
||||
sm_scale, d_v)
|
||||
return results
|
||||
|
||||
|
||||
#
|
||||
|
@ -50,6 +50,7 @@ class PagedAttention:
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
|
@ -144,6 +144,7 @@ def get_attn_backend(
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
@ -158,6 +159,7 @@ def get_attn_backend(
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
)
|
||||
|
||||
|
||||
@ -170,6 +172,7 @@ def _cached_get_attn_backend(
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
@ -203,7 +206,7 @@ def _cached_get_attn_backend(
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
|
||||
use_mla, has_sink)
|
||||
use_mla, has_sink, use_sparse)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}")
|
||||
|
@ -22,7 +22,8 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2",
|
||||
"fp8_inc"]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
|
||||
@ -52,7 +53,11 @@ class CacheConfig:
|
||||
cache_dtype: CacheDType = "auto"
|
||||
"""Data type for kv cache storage. If "auto", will use model data type.
|
||||
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
|
||||
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
|
||||
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
|
||||
Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use
|
||||
bfloat16 instead, this is an invalid option for models that do not default
|
||||
to fp8.
|
||||
"""
|
||||
is_attention_free: bool = False
|
||||
"""Whether the model is attention-free. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
@ -171,11 +176,12 @@ class CacheConfig:
|
||||
if self.cache_dtype == "auto":
|
||||
pass
|
||||
elif self.cache_dtype in get_args(CacheDType):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor.")
|
||||
if self.cache_dtype.startswith("fp8"):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor.")
|
||||
else:
|
||||
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
|
||||
|
||||
|
@ -360,6 +360,7 @@ class CompilationConfig:
|
||||
"vllm.linear_attention",
|
||||
"vllm.plamo2_mamba_mixer",
|
||||
"vllm.gdn_attention",
|
||||
"vllm.sparse_attn_indexer",
|
||||
]
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
|
@ -1077,14 +1077,14 @@ class ModelConfig:
|
||||
if not hasattr(self.hf_text_config, "model_type"):
|
||||
return False
|
||||
elif self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp',
|
||||
'kimi_k2', 'longcat_flash'):
|
||||
return self.hf_text_config.kv_lora_rank is not None
|
||||
elif self.hf_text_config.model_type == 'eagle':
|
||||
# if the model is an EAGLE module, check for the
|
||||
# underlying architecture
|
||||
return self.hf_text_config.model.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3') \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
|
||||
and self.hf_text_config.kv_lora_rank is not None
|
||||
return False
|
||||
|
||||
|
@ -32,14 +32,17 @@ logger = init_logger(__name__)
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
|
||||
"longcat_flash_mtp"]
|
||||
"longcat_flash_mtp", "mtp"]
|
||||
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp")
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeculativeConfig:
|
||||
"""Configuration for speculative decoding."""
|
||||
|
||||
enforce_eager: Optional[bool] = None
|
||||
"""Override the default enforce_eager from model_config"""
|
||||
# General speculative decoding control
|
||||
num_speculative_tokens: SkipValidation[int] = None # type: ignore
|
||||
"""The number of speculative tokens, if provided. It will default to the
|
||||
@ -143,7 +146,7 @@ class SpeculativeConfig:
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
if hf_config.model_type == "deepseek_v3":
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
@ -207,11 +210,21 @@ class SpeculativeConfig:
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.method in MTP_MODEL_TYPES:
|
||||
logger.warning("method `%s` is deprecated and replaced with mtp.",
|
||||
self.method)
|
||||
self.method = "mtp"
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||
if (self.target_model_config
|
||||
and self.target_model_config.hf_text_config.model_type
|
||||
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
|
||||
if self.method == "mtp":
|
||||
assert (
|
||||
self.target_model_config
|
||||
is not None), "target_model_config must be present for mtp"
|
||||
if self.target_model_config.hf_text_config.model_type \
|
||||
== "deepseek_v32":
|
||||
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
|
||||
# remove this when the issue is fixed.
|
||||
self.enforce_eager = True
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
@ -314,31 +327,13 @@ class SpeculativeConfig:
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
in MTP_MODEL_TYPES):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Deepseek MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"ernie_mtp"):
|
||||
self.method = "ernie_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Ernie MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"qwen3_next_mtp"):
|
||||
self.method = "qwen3_next_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Qwen3Next MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
"Enabling num_speculative_tokens > 1 will run" \
|
||||
"multiple times of forward on same MTP layer" \
|
||||
",which may result in lower acceptance rate" \
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("longcat_flash_mtp")):
|
||||
@ -355,7 +350,7 @@ class SpeculativeConfig:
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or deepseek_mtp.")
|
||||
"eagle, or mtp.")
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
@ -564,8 +559,7 @@ class SpeculativeConfig:
|
||||
return self.num_speculative_tokens
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp")
|
||||
return self.method in ("eagle", "eagle3", "mtp")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
|
@ -54,6 +54,7 @@ class HTTPConnection:
|
||||
stream: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
extra_headers: Optional[Mapping[str, str]] = None,
|
||||
allow_redirects: bool = True,
|
||||
):
|
||||
self._validate_http_url(url)
|
||||
|
||||
@ -63,7 +64,8 @@ class HTTPConnection:
|
||||
return client.get(url,
|
||||
headers=self._headers(**extra_headers),
|
||||
stream=stream,
|
||||
timeout=timeout)
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects)
|
||||
|
||||
async def get_async_response(
|
||||
self,
|
||||
@ -71,6 +73,7 @@ class HTTPConnection:
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
extra_headers: Optional[Mapping[str, str]] = None,
|
||||
allow_redirects: bool = True,
|
||||
):
|
||||
self._validate_http_url(url)
|
||||
|
||||
@ -79,10 +82,17 @@ class HTTPConnection:
|
||||
|
||||
return client.get(url,
|
||||
headers=self._headers(**extra_headers),
|
||||
timeout=timeout)
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects)
|
||||
|
||||
def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
|
||||
with self.get_response(url, timeout=timeout) as r:
|
||||
def get_bytes(self,
|
||||
url: str,
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
allow_redirects: bool = True) -> bytes:
|
||||
with self.get_response(url,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.content
|
||||
@ -92,8 +102,10 @@ class HTTPConnection:
|
||||
url: str,
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
allow_redirects: bool = True,
|
||||
) -> bytes:
|
||||
async with await self.get_async_response(url, timeout=timeout) as r:
|
||||
async with await self.get_async_response(
|
||||
url, timeout=timeout, allow_redirects=allow_redirects) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return await r.read()
|
||||
|
@ -84,7 +84,7 @@ class NixlAgentMetadata(
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
block_lens: list[int]
|
||||
attn_backend_name: str
|
||||
kv_cache_layout: str
|
||||
|
||||
@ -766,6 +766,9 @@ class NixlConnectorWorker:
|
||||
split_k_and_v = not (self.use_mla or self._use_pallas
|
||||
or self._use_flashinfer)
|
||||
tensor_size_bytes = None
|
||||
# Enable different block lengths for different layers when MLA is used.
|
||||
self.block_len_per_layer = list[int]()
|
||||
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
|
||||
for layer_name, cache_or_caches in xfer_buffers.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [
|
||||
cache_or_caches
|
||||
@ -783,10 +786,25 @@ class NixlConnectorWorker:
|
||||
tensor_size_bytes = curr_tensor_size_bytes
|
||||
self.num_blocks = cache.shape[0]
|
||||
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, \
|
||||
"All kv cache tensors must have the same size"
|
||||
assert cache.shape[0] == self.num_blocks, \
|
||||
"All kv cache tensors must have the same number of blocks"
|
||||
|
||||
self.block_len_per_layer.append(curr_tensor_size_bytes //
|
||||
self.num_blocks)
|
||||
self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
|
||||
self.block_size)
|
||||
|
||||
if not self.use_mla:
|
||||
# Different kv cache shape is not supported by HeteroTP
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, \
|
||||
"All kv cache tensors must have the same size"
|
||||
caches_data.append(
|
||||
(base_addr, tensor_size_bytes, self.tp_rank, ""))
|
||||
(base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
|
||||
|
||||
logger.debug("Different block lengths collected: %s",
|
||||
set(self.block_len_per_layer))
|
||||
assert len(self.block_len_per_layer) == len(seen_base_addresses)
|
||||
assert self.num_blocks != 0
|
||||
|
||||
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
|
||||
self.num_regions = len(caches_data)
|
||||
@ -799,16 +817,12 @@ class NixlConnectorWorker:
|
||||
logger.debug("Done registering descs")
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
assert tensor_size_bytes is not None
|
||||
assert self.num_blocks != 0
|
||||
assert tensor_size_bytes % self.num_blocks == 0
|
||||
self.block_len = tensor_size_bytes // self.num_blocks
|
||||
self.slot_size_bytes = self.block_len // self.block_size
|
||||
self.device_kv_caches = kv_caches
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
if self._use_flashinfer:
|
||||
assert self.slot_size_bytes % 2 == 0
|
||||
self.slot_size_bytes /= 2
|
||||
for i in range(len(self.slot_size_per_layer)):
|
||||
assert self.slot_size_per_layer[i] % 2 == 0
|
||||
self.slot_size_per_layer[i] //= 2
|
||||
|
||||
# NOTE (NickLucche) When FlashInfer is used, memory is registered
|
||||
# with joint KV for each block. This minimizes the overhead in
|
||||
@ -818,17 +832,17 @@ class NixlConnectorWorker:
|
||||
# of 'virtual' regions here and halve `block_len` below.
|
||||
self.num_regions *= 2
|
||||
|
||||
kv_block_len = self.get_backend_aware_kv_block_len()
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
for base_addr in seen_base_addresses:
|
||||
for i, base_addr in enumerate(seen_base_addresses):
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
||||
# could create fewer, but then _get_block_descs_ids needs to
|
||||
# select agent_meta.num_blocks instead of self.num_blocks for
|
||||
# local descr, and that makes handling regular flow less clean.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
block_offset = block_id * self.block_len_per_layer[i]
|
||||
addr = base_addr + block_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, kv_block_len, self.tp_rank))
|
||||
@ -838,7 +852,7 @@ class NixlConnectorWorker:
|
||||
# descs ordering. This is needed for selecting contiguous heads
|
||||
# when split across TP ranks.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
block_offset = block_id * self.block_len_per_layer[i]
|
||||
addr = base_addr + block_offset
|
||||
# Register addresses for V cache (K registered first).
|
||||
v_addr = addr + kv_block_len
|
||||
@ -878,7 +892,7 @@ class NixlConnectorWorker:
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
num_blocks=self.num_blocks,
|
||||
block_len=self.block_len,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
kv_cache_layout=self.kv_cache_layout)
|
||||
ready_event = threading.Event()
|
||||
@ -903,7 +917,7 @@ class NixlConnectorWorker:
|
||||
The latter, assuming D.world_size > P.world_size, requires that two or
|
||||
more local TP worker share the xfer from a single TP worker.
|
||||
|
||||
Here's an example:
|
||||
Here's an example (non-MLA case):
|
||||
|
||||
rank_offset p_remote_tp_rank
|
||||
(kv split no)
|
||||
@ -959,14 +973,20 @@ class NixlConnectorWorker:
|
||||
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
|
||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||
|
||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||
if self.use_mla or is_kv_replicated:
|
||||
# With MLA the only difference is in the number of blocks.
|
||||
remote_block_size = nixl_agent_meta.block_len // (
|
||||
self.slot_size_bytes)
|
||||
assert self.block_len == nixl_agent_meta.block_len
|
||||
# With replicated KV cache, only the number of blocks can differ.
|
||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
|
||||
"KV cache sizes must match between P and D when replicated"
|
||||
remote_block_size = remote_block_len // (
|
||||
self.slot_size_per_layer[0])
|
||||
else:
|
||||
remote_block_size = nixl_agent_meta.block_len // (
|
||||
self.slot_size_bytes * tp_ratio)
|
||||
# When MLA is not used, this is a list of the same block length
|
||||
for block_len in nixl_agent_meta.block_lens:
|
||||
assert block_len == remote_block_len, \
|
||||
"All remote layers must have the same block size"
|
||||
remote_block_size = remote_block_len // (
|
||||
self.slot_size_per_layer[0] * tp_ratio)
|
||||
if self._use_flashinfer:
|
||||
# With flashinfer, KV are sent in the same message.
|
||||
remote_block_size //= 2
|
||||
@ -977,14 +997,14 @@ class NixlConnectorWorker:
|
||||
raise ValueError(
|
||||
"Heterogeneous TP is not supported on XPU")
|
||||
|
||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
|
||||
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
||||
)
|
||||
|
||||
assert self.block_size == remote_block_size, (
|
||||
"Remote P worker with different block size is not supported "
|
||||
f"{self.block_size=} {remote_block_size=}")
|
||||
"Remote P worker with different page/block size is not supported "
|
||||
f"{self.block_size=}, {remote_block_size=}")
|
||||
|
||||
# Create dst descs and xfer side handles. TP workers have same #blocks.
|
||||
if engine_id in self.dst_num_blocks:
|
||||
@ -999,13 +1019,16 @@ class NixlConnectorWorker:
|
||||
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
||||
self.kv_caches_base_addr[
|
||||
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
kv_block_len = self.get_backend_aware_kv_block_len()
|
||||
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
|
||||
if not (self.use_mla or is_kv_replicated) else 0
|
||||
|
||||
assert len(nixl_agent_meta.kv_caches_base_addr) == len(
|
||||
self.block_len_per_layer)
|
||||
# Register all remote blocks, but only the corresponding kv heads.
|
||||
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
|
||||
if not (self.use_mla or is_kv_replicated) else 0
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_len
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
# For each block, grab the heads chunk belonging to rank_i
|
||||
# of size remote_nheads // tp_ratio, which correspond to
|
||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||
@ -1016,9 +1039,9 @@ class NixlConnectorWorker:
|
||||
if self._use_flashinfer:
|
||||
# With FlashInfer index V separately to allow head splitting.
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_len
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
v_addr = addr + nixl_agent_meta.block_len // 2
|
||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
|
||||
|
||||
logger.debug(
|
||||
@ -1345,7 +1368,7 @@ class NixlConnectorWorker:
|
||||
descs_ids = region_ids * num_blocks + block_ids
|
||||
return descs_ids.flatten()
|
||||
|
||||
def get_backend_aware_kv_block_len(self):
|
||||
def get_backend_aware_kv_block_len(self, layer_idx: int):
|
||||
"""
|
||||
Get the block length for one K/V element (K and V have the same size).
|
||||
|
||||
@ -1356,9 +1379,9 @@ class NixlConnectorWorker:
|
||||
"""
|
||||
if self._use_flashinfer:
|
||||
# For indexing only half (either just the K or V part).
|
||||
block_len = self.block_len // 2
|
||||
block_len = self.block_len_per_layer[layer_idx] // 2
|
||||
else:
|
||||
block_len = self.block_len
|
||||
block_len = self.block_len_per_layer[layer_idx]
|
||||
return block_len
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
|
||||
|
@ -1486,7 +1486,7 @@ class EngineArgs:
|
||||
raise NotImplementedError(
|
||||
"Draft model speculative decoding is not supported yet. "
|
||||
"Please consider using other speculative decoding methods "
|
||||
"such as ngram, medusa, eagle, or deepseek_mtp.")
|
||||
"such as ngram, medusa, eagle, or mtp.")
|
||||
|
||||
V1_BACKENDS = [
|
||||
"FLASH_ATTN",
|
||||
|
@ -68,6 +68,7 @@ if TYPE_CHECKING:
|
||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
|
||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||
@ -725,6 +726,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||
|
||||
# Whether to allow HTTP redirects when fetching from media URLs.
|
||||
# Default to True
|
||||
"VLLM_MEDIA_URL_ALLOW_REDIRECTS":
|
||||
lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))),
|
||||
|
||||
# Max number of workers for the thread pool handling
|
||||
# media bytes loading. Set to 1 to disable parallel processing.
|
||||
# Default is 8
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
@ -375,3 +376,20 @@ class PolyNorm(CustomOp):
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"""
|
||||
Layer Normalization.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
||||
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
|
||||
self.eps).type_as(x)
|
||||
|
@ -24,6 +24,9 @@ class MLAModules:
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_b_proj: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
indexer: Optional[torch.nn.Module]
|
||||
is_sparse: bool
|
||||
topk_indices_buffer: Optional[torch.Tensor]
|
||||
|
||||
|
||||
@CustomOp.register("multi_head_latent_attention")
|
||||
@ -76,6 +79,13 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
self.kv_b_proj = mla_modules.kv_b_proj
|
||||
self.rotary_emb = mla_modules.rotary_emb
|
||||
self.o_proj = mla_modules.o_proj
|
||||
self.indexer = mla_modules.indexer
|
||||
self.is_sparse = mla_modules.is_sparse
|
||||
|
||||
if self.indexer is not None:
|
||||
assert hasattr(self.indexer, "topk_tokens")
|
||||
self.topk_tokens = self.indexer.topk_tokens
|
||||
self.topk_indices_buffer = mla_modules.topk_indices_buffer
|
||||
|
||||
# In the MLA backend, kv_cache includes both k_c and
|
||||
# pe (i.e. decoupled position embeddings). In particular,
|
||||
@ -92,6 +102,7 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=mla_modules.is_sparse,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
@ -100,6 +111,7 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
indexer=self.indexer,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
@ -145,6 +157,10 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
|
||||
positions, q[..., self.qk_nope_head_dim:], k_pe)
|
||||
|
||||
if self.indexer and self.is_sparse:
|
||||
_topk_indices = self.indexer(hidden_states, q_c, positions,
|
||||
self.rotary_emb)
|
||||
|
||||
attn_out = self.mla_attn(
|
||||
q,
|
||||
kv_c_normed,
|
||||
|
@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
|
||||
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
|
||||
# requantize the weight and input to the specific scale
|
||||
# at the same time.
|
||||
if is_deep_gemm_e8m0_used():
|
||||
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
|
||||
layer.orig_dtype, layer.weight)
|
||||
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(layer.weight.data,
|
||||
layer.weight_scale.data, block_sz)
|
||||
# SM90 Block FP8 CUTLASS requires row-major weight scales
|
||||
elif (current_platform.is_device_capability(90)
|
||||
and cutlass_block_fp8_supported
|
||||
and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
|
||||
layer.weight)):
|
||||
and cutlass_block_fp8_supported and not should_use_deepgemm):
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False)
|
||||
|
||||
|
@ -346,8 +346,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=model_config.use_mla).page_size_bytes
|
||||
dtype=kv_cache_dtype).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
@ -401,6 +400,31 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
"exactly equal.", mamba_padding_pct)
|
||||
|
||||
|
||||
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
|
||||
"""
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
|
||||
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
|
||||
is_v32 = hasattr(hf_config, "index_topk")
|
||||
assert is_v32
|
||||
|
||||
# For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
|
||||
# "auto")
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.cache_dtype == "auto" or \
|
||||
cache_config.cache_dtype.startswith("fp8"):
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
|
||||
if cache_config.cache_dtype == "bfloat16":
|
||||
cache_config.cache_dtype = "auto"
|
||||
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
||||
|
||||
|
||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
@ -417,4 +441,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"MambaForCausalLM": MambaModelConfig,
|
||||
"Mamba2ForCausalLM": MambaModelConfig,
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||
}
|
||||
|
@ -53,8 +53,20 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
|
||||
self.is_v32 = hasattr(config, "index_topk")
|
||||
if self.is_v32:
|
||||
topk_tokens = config.index_topk
|
||||
topk_indices_buffer = torch.empty(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
topk_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
else:
|
||||
topk_indices_buffer = None
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
|
||||
topk_indices_buffer)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -33,15 +33,21 @@ from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
@ -49,6 +55,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -56,13 +64,26 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
|
||||
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
|
||||
DeepseekV32IndexerMetadata)
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
|
||||
@ -276,6 +297,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
@ -289,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
topk_indices_buffer: Optional[torch.Tensor] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -306,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
assert topk_indices_buffer is None, "topk_indices_buffer is not \
|
||||
supported for DeepseekV2Attention"
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
@ -418,6 +443,390 @@ class DeepseekV2Attention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
|
||||
cache_config: CacheConfig):
|
||||
super().__init__()
|
||||
self.kv_cache = [torch.tensor([])]
|
||||
self.head_dim = head_dim
|
||||
self.prefix = prefix
|
||||
self.cache_config = cache_config
|
||||
self.dtype = dtype
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
return MLAAttentionSpec( # Only has one vector instead of K + V
|
||||
block_size=self.cache_config.block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=self.head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def forward(self):
|
||||
...
|
||||
|
||||
def get_attn_backend(self) -> AttentionBackend:
|
||||
return DeepseekV32IndexerBackend
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def cp_gather_indexer_k_quant_cache(
|
||||
kv_cache, # [num_blocks, block_size, head_dim + 1]
|
||||
dst_value, # [cu_seq_lens[-1], head_dim]
|
||||
dst_scale, # [cu_seq_lens[-1], 4]
|
||||
block_table, # [batch_size, num_blocks]
|
||||
cu_seq_lens, # [batch_size + 1, ]
|
||||
batch_size,
|
||||
):
|
||||
num_blocks, block_size, _ = kv_cache.shape
|
||||
head_dim = dst_value.shape[-1]
|
||||
kv_cache = kv_cache.view(num_blocks, -1)
|
||||
|
||||
expected_value = []
|
||||
expected_scale = []
|
||||
for b in range(batch_size):
|
||||
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
|
||||
if s == 0:
|
||||
continue
|
||||
tot = cdiv(s, block_size)
|
||||
blocks = block_table[b, :tot]
|
||||
|
||||
value = []
|
||||
scale = []
|
||||
full_block = torch.arange(tot - 1,
|
||||
device=kv_cache.device,
|
||||
dtype=torch.int32)
|
||||
non_remaining_value = kv_cache[blocks[full_block], :block_size *
|
||||
head_dim].view(-1, head_dim)
|
||||
non_remaining_scale = kv_cache[blocks[full_block],
|
||||
block_size * head_dim:].view(-1, 4)
|
||||
|
||||
remaining = s - (tot - 1) * block_size
|
||||
|
||||
value = torch.cat([
|
||||
non_remaining_value,
|
||||
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
|
||||
],
|
||||
dim=0)
|
||||
scale = torch.cat([
|
||||
non_remaining_scale,
|
||||
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
|
||||
remaining * 4].view(-1, 4)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
expected_value.append(value)
|
||||
expected_scale.append(scale)
|
||||
|
||||
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
|
||||
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
|
||||
gather_value = gather_value.view(torch.float8_e4m3fn)
|
||||
gather_scale = gather_scale.view(torch.float32)
|
||||
dst_value.copy_(gather_value)
|
||||
dst_scale.copy_(gather_scale)
|
||||
|
||||
|
||||
def sparse_attn_indexer(
|
||||
hidden_states: torch.Tensor,
|
||||
k_cache_prefix: str,
|
||||
kv_cache: torch.Tensor,
|
||||
q_fp8: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
scale_fmt: Optional[str],
|
||||
topk_tokens: int,
|
||||
head_dim: int,
|
||||
max_model_len: int,
|
||||
total_seq_lens: int,
|
||||
topk_indices_buffer: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
# careful! this will be None in dummy run
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# assert isinstance(attn_metadata, dict)
|
||||
if not isinstance(attn_metadata, dict):
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
kv_cache,
|
||||
q_fp8,
|
||||
k,
|
||||
weights,
|
||||
quant_block_size,
|
||||
scale_fmt,
|
||||
topk_tokens,
|
||||
head_dim,
|
||||
max_model_len,
|
||||
total_seq_lens,
|
||||
topk_indices_buffer,
|
||||
)
|
||||
attn_metadata = attn_metadata[k_cache_prefix]
|
||||
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
ops.indexer_k_quant_and_cache(
|
||||
k,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
quant_block_size,
|
||||
scale_fmt,
|
||||
)
|
||||
|
||||
topk_indices_buffer[:hidden_states.shape[0]] = -1
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
|
||||
device=k.device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
k_scale = torch.empty([chunk.total_seq_lens, 1],
|
||||
device=k.device,
|
||||
dtype=torch.float32)
|
||||
cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
chunk.num_reqs,
|
||||
)
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start:chunk.token_end],
|
||||
(k_fp8, k_scale),
|
||||
weights[chunk.token_start:chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
)
|
||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
||||
dim=-1)[1]
|
||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||
mask_lo = topk_indices >= 0
|
||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||
chunk.cu_seqlen_ks)[:, None] < 0
|
||||
mask = torch.full_like(topk_indices,
|
||||
False,
|
||||
dtype=torch.bool,
|
||||
device=topk_indices.device)
|
||||
mask = mask_lo & mask_hi
|
||||
topk_indices = topk_indices.masked_fill(~mask, -1)
|
||||
topk_indices_buffer[
|
||||
chunk.token_start:chunk.token_end, :topk_indices.
|
||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
if has_decode:
|
||||
decode_metadata = attn_metadata.decode
|
||||
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
|
||||
# we only have [num_block, block_size, head_dim],
|
||||
kv_cache = kv_cache.unsqueeze(-2)
|
||||
decode_lens = decode_metadata.decode_lens
|
||||
if decode_metadata.requires_padding:
|
||||
# pad in edge case where we have short chunked prefill length <
|
||||
# decode_threshold since we unstrictly split
|
||||
# prefill and decode by decode_threshold
|
||||
# (currently set to 1 + speculative tokens)
|
||||
padded_q_fp8_decode_tokens = pack_seq_triton(
|
||||
q_fp8[:num_decode_tokens], decode_lens)
|
||||
else:
|
||||
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
|
||||
decode_lens.shape[0], -1, *q_fp8.shape[1:])
|
||||
# TODO: move and optimize below logic with triton kernels
|
||||
batch_size = padded_q_fp8_decode_tokens.shape[0]
|
||||
next_n = padded_q_fp8_decode_tokens.shape[1]
|
||||
assert batch_size == decode_metadata.seq_lens.shape[0]
|
||||
num_padded_tokens = batch_size * next_n
|
||||
logits = fp8_paged_mqa_logits(
|
||||
padded_q_fp8_decode_tokens,
|
||||
kv_cache,
|
||||
weights[:num_padded_tokens],
|
||||
decode_metadata.seq_lens,
|
||||
decode_metadata.block_table,
|
||||
decode_metadata.schedule_metadata,
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
# padded query len
|
||||
current_device = padded_q_fp8_decode_tokens.device
|
||||
padded_num_tokens = batch_size * next_n
|
||||
positions = torch.arange(max_model_len,
|
||||
device=current_device).unsqueeze(0).expand(
|
||||
batch_size * next_n, -1)
|
||||
row_indices = torch.arange(padded_num_tokens,
|
||||
device=current_device) // next_n
|
||||
next_n_offset = torch.arange(
|
||||
padded_num_tokens,
|
||||
device=padded_q_fp8_decode_tokens.device) % next_n
|
||||
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
|
||||
next_n_offset).unsqueeze(1)
|
||||
# index_end_pos: [B * N, 1]
|
||||
mask = positions <= index_end_pos
|
||||
# mask: [B * N, L]
|
||||
logits = logits.masked_fill(~mask, float('-inf'))
|
||||
topk_indices = logits.topk(topk_tokens,
|
||||
dim=-1)[1].to(torch.int32) # [B * N, K]
|
||||
# ensure we don't set indices for the top k
|
||||
# that is out of range(masked already)
|
||||
# this will happen if context length is shorter than K
|
||||
topk_indices[topk_indices > index_end_pos] = -1
|
||||
if decode_metadata.requires_padding:
|
||||
# if padded, we need to unpack
|
||||
# the topk indices removing padded tokens
|
||||
topk_indices = unpack_seq_triton(
|
||||
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
|
||||
decode_lens)
|
||||
topk_indices_buffer[:num_decode_tokens, :topk_indices.
|
||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
return topk_indices_buffer
|
||||
|
||||
|
||||
def sparse_attn_indexer_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
k_cache_prefix: str,
|
||||
kv_cache: torch.Tensor,
|
||||
q_fp8: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
scale_fmt: Optional[str],
|
||||
topk_tokens: int,
|
||||
head_dim: int,
|
||||
max_model_len: int,
|
||||
total_seq_lens: int,
|
||||
topk_indices_buffer: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# profile run
|
||||
# NOTE(Chen): create the max possible flattened_kv. So that
|
||||
# profile_run can get correct memory usage.
|
||||
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
|
||||
device=k.device,
|
||||
dtype=torch.uint8)
|
||||
_k_fp8 = _flattened_kv[..., :head_dim].view(
|
||||
torch.float8_e4m3fn).contiguous()
|
||||
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
|
||||
return topk_indices_buffer
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sparse_attn_indexer",
|
||||
op_func=sparse_attn_indexer,
|
||||
mutates_args=["topk_indices_buffer"],
|
||||
fake_impl=sparse_attn_indexer_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
q_lora_rank: int,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
topk_indices_buffer: Optional[torch.Tensor],
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.vllm_config = vllm_config
|
||||
self.config = config
|
||||
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
|
||||
self.topk_tokens = config.index_topk
|
||||
self.n_head = config.index_n_heads # 64
|
||||
self.head_dim = config.index_head_dim # 128
|
||||
self.rope_dim = config.qk_rope_head_dim # 64
|
||||
self.q_lora_rank = q_lora_rank # 1536
|
||||
# no tensor parallel, just replicated
|
||||
self.wq_b = ReplicatedLinear(self.q_lora_rank,
|
||||
self.head_dim * self.n_head,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq_b")
|
||||
self.wk = ReplicatedLinear(hidden_size,
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk")
|
||||
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
|
||||
self.weights_proj = ReplicatedLinear(hidden_size,
|
||||
self.n_head,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.weights_proj")
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
self.scale_fmt = "ue8m0"
|
||||
self.quant_block_size = 128 # TODO: get from config
|
||||
self.topk_indices_buffer = topk_indices_buffer
|
||||
|
||||
# NOTE: (zyongye) we use fp8 naive cache,
|
||||
# where we store value in fp8 and scale in fp32
|
||||
# per self.quant_block_size element
|
||||
self.k_cache = DeepseekV32IndexerCache(
|
||||
head_dim=self.head_dim +
|
||||
self.head_dim // self.quant_block_size * 4,
|
||||
dtype=torch.uint8,
|
||||
prefix=f"{prefix}.k_cache",
|
||||
cache_config=cache_config)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.prefix = prefix
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
get_max_prefill_buffer_size)
|
||||
self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
|
||||
rotary_emb) -> torch.Tensor:
|
||||
q, _ = self.wq_b(qr)
|
||||
q = q.view(-1, self.n_head, self.head_dim)
|
||||
q_pe, q_nope = torch.split(
|
||||
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
|
||||
|
||||
k, _ = self.wk(hidden_states)
|
||||
k = self.k_norm(k)
|
||||
k_pe, k_nope = torch.split(
|
||||
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
|
||||
|
||||
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
|
||||
q = torch.cat([q_pe, q_nope], dim=-1)
|
||||
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
|
||||
|
||||
# we only quant q here since k quant is fused with cache insertion
|
||||
q = q.view(-1, self.head_dim)
|
||||
q_fp8, q_scale = per_token_group_quant_fp8(q,
|
||||
self.quant_block_size,
|
||||
column_major_scales=False,
|
||||
use_ue8m0=self.scale_fmt
|
||||
is not None)
|
||||
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
|
||||
q_scale = q_scale.view(-1, self.n_head, 1)
|
||||
|
||||
weights, _ = self.weights_proj(hidden_states)
|
||||
weights = weights.unsqueeze(
|
||||
-1) * q_scale * self.softmax_scale * self.n_head**-0.5
|
||||
weights = weights.squeeze(-1)
|
||||
|
||||
return torch.ops.vllm.sparse_attn_indexer(
|
||||
hidden_states,
|
||||
self.k_cache.prefix,
|
||||
self.k_cache.kv_cache[0],
|
||||
q_fp8,
|
||||
k,
|
||||
weights,
|
||||
self.quant_block_size,
|
||||
self.scale_fmt,
|
||||
self.topk_tokens,
|
||||
self.head_dim,
|
||||
self.max_model_len,
|
||||
self.max_total_seq_len,
|
||||
self.topk_indices_buffer,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV2MLAAttention(nn.Module):
|
||||
"""
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
@ -429,6 +838,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
@ -443,6 +853,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
topk_indices_buffer: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -523,6 +934,15 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
self.is_v32 = hasattr(config, "index_topk")
|
||||
|
||||
if self.is_v32:
|
||||
self.indexer = Indexer(vllm_config, config, hidden_size,
|
||||
q_lora_rank, quant_config, cache_config,
|
||||
topk_indices_buffer, f"{prefix}.indexer")
|
||||
else:
|
||||
self.indexer = None
|
||||
|
||||
mla_modules = MLAModules(
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
@ -536,7 +956,11 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else None,
|
||||
indexer=self.indexer,
|
||||
is_sparse=self.is_v32,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
self.mla_attn = MultiHeadLatentAttention(
|
||||
self.hidden_size,
|
||||
self.num_local_heads,
|
||||
@ -562,7 +986,10 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -585,6 +1012,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@ -600,6 +1028,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
@ -683,6 +1112,16 @@ class DeepseekV2Model(nn.Module):
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.is_v32 = hasattr(config, "index_topk")
|
||||
if self.is_v32:
|
||||
topk_tokens = config.index_topk
|
||||
topk_indices_buffer = torch.empty(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
topk_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
else:
|
||||
topk_indices_buffer = None
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
@ -695,7 +1134,8 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
|
||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
|
||||
topk_indices_buffer),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
|
@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: FlashConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
@ -329,6 +330,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
# Dual attention structure
|
||||
self.self_attn = nn.ModuleList([
|
||||
DeepseekV2MLAAttention(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@ -454,6 +456,7 @@ class FlashModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: FlashDecoderLayer(
|
||||
vllm_config,
|
||||
config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
|
@ -66,7 +66,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.platforms import _Backend
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import uses_mrope
|
||||
from vllm.utils import is_list_of
|
||||
@ -335,14 +335,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now.")
|
||||
if current_platform.is_device_capability(
|
||||
100) and self.attn_backend != _Backend.TORCH_SDPA:
|
||||
# TODO(Roger/Wentao): remove this after FA
|
||||
# or XFORMERS's issue fixed on Blackwell
|
||||
logger.info_once("Qwen3-VL vision attention does not support "
|
||||
f"{self.attn_backend} backend on Blackwell now. "
|
||||
"Vision attention backend is set to TORCH_SDPA.")
|
||||
self.attn_backend = _Backend.TORCH_SDPA
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen3_VisionBlock(
|
||||
@ -1134,14 +1126,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
if not multimodal_config.get_limit_per_prompt("image") and \
|
||||
not multimodal_config.get_limit_per_prompt("video"):
|
||||
self.visual = None
|
||||
else:
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
|
||||
self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
@ -1157,11 +1152,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config.vision_config.deepstack_visual_indexes
|
||||
) if self.use_deepstack else 0
|
||||
# register buffer for deepstack
|
||||
self.deepstack_input_embeds = [
|
||||
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.hidden_size)
|
||||
for _ in range(self.deepstack_num_level)
|
||||
] if self.use_deepstack else None
|
||||
if self.use_deepstack and self.visual is not None:
|
||||
self.deepstack_input_embeds = [
|
||||
torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.hidden_size)
|
||||
for _ in range(self.deepstack_num_level)
|
||||
]
|
||||
else:
|
||||
self.deepstack_input_embeds = None
|
||||
self.visual_dim = config.vision_config.out_hidden_size
|
||||
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
|
||||
|
||||
@ -1596,7 +1595,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
skip_prefixes = []
|
||||
if self.visual is None:
|
||||
skip_prefixes.extend(["visual."])
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
|
@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
self.multimodal_config = multimodal_config
|
||||
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
|
||||
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
if not multimodal_config.get_limit_per_prompt("image") and \
|
||||
not multimodal_config.get_limit_per_prompt("video"):
|
||||
self.visual = None
|
||||
else:
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
)
|
||||
|
||||
self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
|
||||
config.vision_config.deepstack_visual_indexes
|
||||
) if self.use_deepstack else 0
|
||||
# register buffer for deepstack
|
||||
self.deepstack_input_embeds = [
|
||||
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.hidden_size)
|
||||
for _ in range(self.deepstack_num_level)
|
||||
] if self.use_deepstack else None
|
||||
if self.use_deepstack and self.visual is not None:
|
||||
self.deepstack_input_embeds = [
|
||||
torch.zeros(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
config.text_config.hidden_size)
|
||||
for _ in range(self.deepstack_num_level)
|
||||
]
|
||||
else:
|
||||
self.deepstack_input_embeds = None
|
||||
self.visual_dim = config.vision_config.out_hidden_size
|
||||
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
|
||||
|
@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
|
||||
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
|
||||
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
|
||||
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
|
||||
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
|
||||
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
|
||||
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
|
||||
|
@ -140,7 +140,11 @@ class MediaConnector:
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
data = connection.get_bytes(url, timeout=fetch_timeout)
|
||||
data = connection.get_bytes(
|
||||
url,
|
||||
timeout=fetch_timeout,
|
||||
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||
)
|
||||
|
||||
return media_io.load_bytes(data)
|
||||
|
||||
@ -167,7 +171,11 @@ class MediaConnector:
|
||||
self._assert_url_in_allowed_media_domains(url_spec)
|
||||
|
||||
connection = self.connection
|
||||
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
|
||||
data = await connection.async_get_bytes(
|
||||
url,
|
||||
timeout=fetch_timeout,
|
||||
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||
)
|
||||
future = loop.run_in_executor(global_thread_pool,
|
||||
media_io.load_bytes, data)
|
||||
return await future
|
||||
|
@ -93,11 +93,14 @@ class CpuPlatform(Platform):
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool, use_mla: bool,
|
||||
has_sink: bool) -> str:
|
||||
has_sink: bool, use_sparse: bool) -> str:
|
||||
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
if use_mla:
|
||||
raise NotImplementedError("MLA is not supported on CPU.")
|
||||
if use_sparse:
|
||||
raise NotImplementedError(
|
||||
"Sparse Attention is not supported on CPU.")
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
if not use_v1:
|
||||
raise ValueError("CPU backend only supports V1.")
|
||||
|
@ -129,6 +129,8 @@ class CudaPlatformBase(Platform):
|
||||
# TODO(lucas): handle this more gracefully
|
||||
# Note: model_config may be None during testing
|
||||
if model_config is not None and model_config.use_mla:
|
||||
use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
|
||||
# then we default to FlashMLA backend for non-blackwell GPUs,
|
||||
# else we default to CutlassMLA. For each case, we force the
|
||||
@ -175,6 +177,12 @@ class CudaPlatformBase(Platform):
|
||||
"Forcing kv cache block size to 64 for FlashInferMLA "
|
||||
"backend.")
|
||||
|
||||
# TODO(Chen): remove this hacky code
|
||||
if use_sparse and cache_config.block_size != 64:
|
||||
cache_config.block_size = 64
|
||||
logger.info(
|
||||
"Forcing kv cache block size to 64 for FlashMLASparse "
|
||||
"backend.")
|
||||
# lazy import to avoid circular import
|
||||
from vllm.config import CUDAGraphMode
|
||||
|
||||
@ -205,6 +213,12 @@ class CudaPlatformBase(Platform):
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int,
|
||||
dtype: torch.dtype) -> _Backend:
|
||||
|
||||
# For Blackwell GPUs, force TORCH_SDPA for now.
|
||||
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
|
||||
if cls.has_device_capability(100):
|
||||
return _Backend.TORCH_SDPA
|
||||
|
||||
if dtype not in (torch.float16, torch.bfloat16):
|
||||
return _Backend.XFORMERS
|
||||
|
||||
@ -225,7 +239,7 @@ class CudaPlatformBase(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink) -> str:
|
||||
has_sink, use_sparse) -> str:
|
||||
if use_mla:
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
@ -235,6 +249,11 @@ class CudaPlatformBase(Platform):
|
||||
from vllm.attention.ops.flashmla import is_flashmla_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||
|
||||
if use_sparse:
|
||||
logger.info_once("Using Sparse MLA backend on V1 engine.")
|
||||
return ("vllm.v1.attention.backends.mla.flashmla_sparse."
|
||||
"FlashMLASparseBackend")
|
||||
|
||||
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
|
||||
selected_backend is None and cls.is_device_capability(100)
|
||||
and block_size == 128)
|
||||
|
@ -194,7 +194,7 @@ class Platform:
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool, use_mla: bool,
|
||||
has_sink: bool) -> str:
|
||||
has_sink: bool, use_sparse: bool) -> str:
|
||||
"""Get the attention backend class of a device."""
|
||||
return ""
|
||||
|
||||
|
@ -195,7 +195,10 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink) -> str:
|
||||
has_sink, use_sparse) -> str:
|
||||
if use_sparse:
|
||||
raise NotImplementedError(
|
||||
"Sparse Attention is not supported on ROCm.")
|
||||
if use_mla:
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
|
@ -49,7 +49,10 @@ class TpuPlatform(Platform):
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool, use_mla: bool,
|
||||
has_sink) -> str:
|
||||
has_sink, use_sparse) -> str:
|
||||
if use_sparse:
|
||||
raise NotImplementedError(
|
||||
"Sparse Attention is not supported on TPU.")
|
||||
if selected_backend != _Backend.PALLAS:
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
|
||||
|
@ -36,7 +36,10 @@ class XPUPlatform(Platform):
|
||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||
dtype: torch.dtype, kv_cache_dtype: Optional[str],
|
||||
block_size: int, use_v1: bool, use_mla: bool,
|
||||
has_sink: bool) -> str:
|
||||
has_sink: bool, use_sparse) -> str:
|
||||
if use_sparse:
|
||||
raise NotImplementedError(
|
||||
"Sparse Attention is not supported on XPU.")
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
if not use_v1:
|
||||
raise ValueError("XPU backend only supports V1.")
|
||||
|
@ -66,6 +66,8 @@ class LazyConfigDict(dict):
|
||||
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
chatglm="ChatGLMConfig",
|
||||
deepseek_vl_v2="DeepseekVLV2Config",
|
||||
deepseek_v3="DeepseekV3Config",
|
||||
deepseek_v32="DeepseekV3Config",
|
||||
kimi_vl="KimiVLConfig",
|
||||
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
|
||||
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
|
||||
|
@ -8,6 +8,7 @@ Model configs may be defined in this directory for the following reasons:
|
||||
"""
|
||||
|
||||
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
||||
from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config
|
||||
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
|
||||
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
@ -37,6 +38,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
||||
__all__ = [
|
||||
"ChatGLMConfig",
|
||||
"DeepseekVLV2Config",
|
||||
"DeepseekV3Config",
|
||||
"DotsOCRConfig",
|
||||
"EAGLEConfig",
|
||||
"RWConfig",
|
||||
|
101
vllm/transformers_utils/configs/deepseek_v3.py
Normal file
101
vllm/transformers_utils/configs/deepseek_v3.py
Normal file
@ -0,0 +1,101 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
|
||||
model_type = "deepseek_v3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=129280,
|
||||
hidden_size=7168,
|
||||
intermediate_size=18432,
|
||||
moe_intermediate_size=2048,
|
||||
num_hidden_layers=61,
|
||||
num_nextn_predict_layers=1,
|
||||
num_attention_heads=128,
|
||||
num_key_value_heads=128,
|
||||
n_shared_experts=1,
|
||||
n_routed_experts=256,
|
||||
ep_size=1,
|
||||
routed_scaling_factor=2.5,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
qk_rope_head_dim=64,
|
||||
v_head_dim=128,
|
||||
qk_nope_head_dim=128,
|
||||
topk_method='noaux_tc',
|
||||
n_group=8,
|
||||
topk_group=4,
|
||||
num_experts_per_tok=8,
|
||||
moe_layer_freq=1,
|
||||
first_k_dense_replace=3,
|
||||
norm_topk_prob=True,
|
||||
scoring_func='sigmoid',
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.ep_size = ep_size
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.topk_method = topk_method
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.scoring_func = scoring_func
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
@ -130,6 +130,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"fp8_e5m2": torch.uint8,
|
||||
"int8": torch.int8,
|
||||
"fp8_inc": torch.float8_e4m3fn,
|
||||
"fp8_ds_mla": torch.uint8,
|
||||
}
|
||||
|
||||
TORCH_DTYPE_TO_NUMPY_DTYPE = {
|
||||
@ -3433,6 +3434,12 @@ def has_triton_kernels() -> bool:
|
||||
return _has_module("triton_kernels")
|
||||
|
||||
|
||||
def has_tilelang() -> bool:
|
||||
"""Whether the optional `tilelang` package is available."""
|
||||
|
||||
return _has_module("tilelang")
|
||||
|
||||
|
||||
def set_process_title(name: str,
|
||||
suffix: str = "",
|
||||
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:
|
||||
|
@ -70,17 +70,25 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
|
||||
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||
_grouped_impl: Callable[..., Any] | None = None
|
||||
_grouped_masked_impl: Callable[..., Any] | None = None
|
||||
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
|
||||
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
|
||||
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
|
||||
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
|
||||
|
||||
|
||||
def _lazy_init() -> None:
|
||||
"""Import deep_gemm and resolve symbols on first use."""
|
||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl,\
|
||||
_get_mn_major_tma_aligned_tensor_impl
|
||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
|
||||
global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
|
||||
global _get_paged_mqa_logits_metadata_impl
|
||||
global _get_mn_major_tma_aligned_tensor_impl
|
||||
|
||||
# fast path
|
||||
if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
|
||||
or _grouped_masked_impl is not None):
|
||||
or _grouped_masked_impl is not None
|
||||
or _fp8_mqa_logits_impl is not None
|
||||
or _fp8_paged_mqa_logits_impl is not None
|
||||
or _get_paged_mqa_logits_metadata_impl is not None):
|
||||
return
|
||||
|
||||
if not has_deep_gemm():
|
||||
@ -97,10 +105,20 @@ def _lazy_init() -> None:
|
||||
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
||||
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
||||
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
|
||||
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
|
||||
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
|
||||
_get_paged_mqa_logits_metadata_impl = getattr(
|
||||
_dg, "get_paged_mqa_logits_metadata", None)
|
||||
_get_mn_major_tma_aligned_tensor_impl = getattr(
|
||||
_dg, "get_mn_major_tma_aligned_tensor", None)
|
||||
|
||||
|
||||
def get_num_sms() -> int:
|
||||
_lazy_init()
|
||||
_dg = importlib.import_module("deep_gemm")
|
||||
return int(_dg.get_num_sms())
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
|
||||
_lazy_init()
|
||||
@ -135,6 +153,100 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
|
||||
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs)
|
||||
|
||||
|
||||
def fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: tuple[torch.Tensor, torch.Tensor],
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits for a single sequence without KV paging.
|
||||
|
||||
Args:
|
||||
q: Query tensor of shape [M, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
|
||||
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
|
||||
[N, 1]) with dtype `torch.float32`.
|
||||
weights: weights of shape [M, H], dtype `torch.float32`.
|
||||
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
cu_seqlen_ke: End indices (exclusive) for valid K per query position,
|
||||
shape [M], dtype int32.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
_lazy_init()
|
||||
if _fp8_mqa_logits_impl is None:
|
||||
return _missing()
|
||||
return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||
|
||||
|
||||
def get_paged_mqa_logits_metadata(context_lens: torch.Tensor, block_size: int,
|
||||
num_sms: int) -> torch.Tensor:
|
||||
"""Build scheduling metadata for paged MQA logits.
|
||||
|
||||
Args:
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
per batch element.
|
||||
block_size: KV-cache block size in tokens (e.g., 64).
|
||||
num_sms: Number of SMs available. 132 for Hopper
|
||||
|
||||
Returns:
|
||||
Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
|
||||
schedule work across SMs.
|
||||
"""
|
||||
_lazy_init()
|
||||
if _get_paged_mqa_logits_metadata_impl is None:
|
||||
return _missing()
|
||||
return _get_paged_mqa_logits_metadata_impl(context_lens, block_size,
|
||||
num_sms)
|
||||
|
||||
|
||||
def fp8_paged_mqa_logits(
|
||||
q_fp8: torch.Tensor,
|
||||
kv_cache_fp8: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
schedule_metadata: torch.Tensor,
|
||||
max_model_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Compute FP8 MQA logits using paged KV-cache.
|
||||
|
||||
Args:
|
||||
q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
|
||||
`torch.float8_e4m3fn` by caller.
|
||||
kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
|
||||
[num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
|
||||
4 bytes per (block,pos) store the `float` dequant scale.
|
||||
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
|
||||
context_lens: Tensor of shape [B], dtype int32; effective context length
|
||||
for each batch element.
|
||||
block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
|
||||
block indices to physical blocks in the paged cache.
|
||||
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
|
||||
used to distribute work across SMs.
|
||||
max_model_len: Maximum sequence length used to size the logits output.
|
||||
|
||||
Returns:
|
||||
Logits tensor of shape [B * next_n, max_model_len], dtype
|
||||
`torch.float32`.
|
||||
"""
|
||||
_lazy_init()
|
||||
if _fp8_paged_mqa_logits_impl is None:
|
||||
return _missing()
|
||||
return _fp8_paged_mqa_logits_impl(q_fp8,
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
clean_logits=True)
|
||||
|
||||
|
||||
def _ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
|
||||
@ -195,9 +307,13 @@ __all__ = [
|
||||
"fp8_gemm_nt",
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
"fp8_mqa_logits",
|
||||
"fp8_paged_mqa_logits",
|
||||
"get_paged_mqa_logits_metadata",
|
||||
"per_block_cast_to_fp8",
|
||||
"is_deep_gemm_e8m0_used",
|
||||
"is_deep_gemm_supported",
|
||||
"get_num_sms",
|
||||
"should_use_deepgemm_for_fp8_linear",
|
||||
"get_col_major_tma_aligned_tensor",
|
||||
]
|
||||
]
|
||||
|
@ -74,6 +74,7 @@ class TorchSDPABackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return _get_paged_attn_impl().get_kv_cache_shape(
|
||||
num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
@ -80,6 +80,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
@ -29,7 +29,6 @@ from vllm.utils.flashinfer import (can_use_trtllm_attention,
|
||||
flashinfer_disable_q_quantization,
|
||||
supports_trtllm_attention,
|
||||
use_trtllm_attention)
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
@ -187,6 +186,7 @@ class FlashInferBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
|
||||
@ -676,7 +676,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
# kv cache dtype to something different from query dtype.
|
||||
return False
|
||||
return use_cascade_attention(*args, **kwargs)
|
||||
# TODO: Cascade attention doesn't work, disable it for now
|
||||
# return use_cascade_attention(*args, **kwargs)
|
||||
return False
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
|
||||
|
@ -88,6 +88,7 @@ class FlexAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
@ -286,6 +286,7 @@ class MLACommonBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@ -407,6 +408,7 @@ class MLACommonMetadata(Generic[D]):
|
||||
|
||||
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
A = TypeVar("A")
|
||||
|
||||
|
||||
def use_flashinfer_prefill() -> bool:
|
||||
@ -930,7 +932,9 @@ def reorg_kvcache(
|
||||
return reorganized_kv_c_normed, reorganized_k_pe
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
|
||||
# and MLACommonImpl -> MLACommonDenseImpl or somthing like that
|
||||
class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@ -956,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
indexer=None,
|
||||
q_pad_num_heads: Optional[int] = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
@ -974,8 +979,140 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.indexer = indexer
|
||||
self.q_pad_num_heads = q_pad_num_heads
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
|
||||
for attr in WEIGHT_NAMES:
|
||||
if hasattr(layer, attr):
|
||||
return getattr(layer, attr)
|
||||
raise AttributeError(
|
||||
f"Layer '{layer}' has no recognized weight attribute:"
|
||||
f" {WEIGHT_NAMES}.")
|
||||
|
||||
def get_and_maybe_dequant_weights(layer: LinearBase):
|
||||
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
|
||||
# NOTE: This should only be used offline, since it's O(N^3)
|
||||
eye = torch.eye(layer.input_size_per_partition,
|
||||
dtype=act_dtype,
|
||||
device=get_layer_weight(layer).device)
|
||||
dequant_weights = layer.quant_method.apply(layer,
|
||||
eye,
|
||||
bias=None)
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
|
||||
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
|
||||
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
W_K = W_UK.transpose(0, 1) # 16 512 128
|
||||
W_V = W_UV.permute(1, 2, 0) # 16 128 512
|
||||
self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
|
||||
W_K, dtype=current_platform.fp8_dtype())
|
||||
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
|
||||
W_V, dtype=current_platform.fp8_dtype())
|
||||
|
||||
# The kernel operates on non-padded inputs. Hence, pre-compiling
|
||||
# triton kernel to avoid runtime compilation for unseen batch sizes
|
||||
# Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
|
||||
# On DS-R1, this step adds roughly 50s to the model loading time.
|
||||
max_batch_size = 1024 # [ToDo] Find the optimal upper limit
|
||||
pre_compilation_list = list(range(1, max_batch_size + 1))
|
||||
if is_global_first_rank():
|
||||
pre_compilation_list = tqdm(
|
||||
pre_compilation_list,
|
||||
desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
|
||||
total=max_batch_size,
|
||||
)
|
||||
|
||||
for m in pre_compilation_list:
|
||||
x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_K.device)
|
||||
aiter_triton_fp8_bmm(x,
|
||||
self.W_K,
|
||||
self.W_K_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
|
||||
x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
|
||||
dtype=torch.bfloat16,
|
||||
device=self.W_V.device)
|
||||
aiter_triton_fp8_bmm(x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
else:
|
||||
# Convert from (L, N, V) to (N, L, V)
|
||||
self.W_UV = W_UV.transpose(0, 1)
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0)
|
||||
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = aiter_triton_fp8_bmm(x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
# Copy result
|
||||
out.copy_(x)
|
||||
else:
|
||||
# Convert from (B, N * V) to (N, B, V)
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
||||
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
out_new = out.transpose(0, 1).reshape(
|
||||
-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
# Adjust output buffer shape back to the original (B, N * V)
|
||||
N, B, V = out.shape
|
||||
out.resize_((B, N * V))
|
||||
out.copy_(out_new) # Copy result
|
||||
|
||||
|
||||
class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if use_flashinfer_prefill():
|
||||
logger.debug_once("Using FlashInfer prefill for MLA")
|
||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
||||
@ -1074,13 +1211,18 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
k, v, return_softmax_lse):
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
assert prefill.prefill_main is not None
|
||||
return prefill.prefill_main.run(
|
||||
ret = prefill.prefill_main.run(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
return_lse=return_softmax_lse,
|
||||
)
|
||||
|
||||
if isinstance(ret, tuple):
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return ret[0], ret[1].transpose(0, 1).contiguous()
|
||||
return ret
|
||||
|
||||
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
|
||||
q, k, v, return_softmax_lse):
|
||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||
@ -1123,12 +1265,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
|
||||
chunk_idx: int, q, k, v):
|
||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||
return prefill.prefill_chunks[chunk_idx].run(
|
||||
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
return_lse=True,
|
||||
)
|
||||
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||
return attn_out, lse.transpose(0, 1).contiguous()
|
||||
|
||||
def _run_prefill_context_chunk_cudnn(self,
|
||||
prefill: MLACommonPrefillMetadata,
|
||||
@ -1154,36 +1298,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
True, #Indicates actual_seq_lens are on GPU or CPU.
|
||||
)
|
||||
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
# Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
|
||||
x = aiter_triton_fp8_bmm(x,
|
||||
self.W_V,
|
||||
self.W_V_scale,
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
# Copy result
|
||||
out.copy_(x)
|
||||
else:
|
||||
# Convert from (B, N * V) to (N, B, V)
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
||||
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
out_new = out.transpose(0, 1).reshape(
|
||||
-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
# Adjust output buffer shape back to the original (B, N * V)
|
||||
N, B, V = out.shape
|
||||
out.resize_((B, N * V))
|
||||
out.copy_(out_new) # Copy result
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
def get_layer_weight(layer):
|
||||
@ -1455,6 +1569,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
attn_metadata: MLACommonMetadata,
|
||||
k_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# TODO (zyongye): Prefill function here
|
||||
assert attn_metadata.prefill is not None
|
||||
assert self.dcp_world_size is not None
|
||||
|
||||
|
@ -177,6 +177,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
layer: AttentionLayer,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# TODO: (zyongye) decode function for mla here
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
|
544
vllm/v1/attention/backends/mla/flashmla_sparse.py
Normal file
544
vllm/v1/attention/backends/mla/flashmla_sparse.py
Normal file
@ -0,0 +1,544 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.flashmla import (flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor:
|
||||
# Convert base-2 LSE to natural-log LSE
|
||||
# Keep FP32 for numerical stability during the merge.
|
||||
return (lse_base2.to(torch.float32) * math.log(2.0))
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "FLASHMLA_SPARSE_VLLM_V1"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type[AttentionMetadata]:
|
||||
return FlashMLASparseMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
|
||||
return FlashMLASparseMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> type["FlashMLASparseImpl"]:
|
||||
return FlashMLASparseImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if cache_dtype_str == "fp8_ds_mla":
|
||||
# custom storage fromat is 656 bytes
|
||||
# see FlashMLA readme.md for details
|
||||
return (num_blocks, block_size, 656)
|
||||
else:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@classmethod
|
||||
def get_supported_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLASparsePrefillMetadata:
|
||||
# NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because
|
||||
# the kernel is not from flashmla
|
||||
block_table: torch.Tensor
|
||||
has_context: bool = False
|
||||
context_lens: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseDecodeAndContextMetadata:
|
||||
scheduler_metadata: torch.Tensor = None
|
||||
num_splits: torch.Tensor = None
|
||||
cache_lens: torch.Tensor = None
|
||||
prefill_context_lengths: Optional[torch.Tensor] = None
|
||||
prefill_new_k_start_locs: Optional[torch.Tensor] = None
|
||||
dummy_block_table: torch.Tensor = None
|
||||
|
||||
def filter_prefill_indices(
|
||||
self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.prefill_context_lengths is not None
|
||||
prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1)
|
||||
context_indices = torch.where(indices < prefill_context_lengths,
|
||||
indices, -1)
|
||||
new_token_indices = torch.where(indices >= prefill_context_lengths,
|
||||
indices - prefill_context_lengths, -1)
|
||||
return context_indices, new_token_indices
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
@dataclass
|
||||
class FP8KernelMetadata:
|
||||
scheduler_metadata: Optional[torch.Tensor]
|
||||
num_splits: torch.Tensor
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
fp8_extra_metadata: Optional[FP8KernelMetadata] = None
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
):
|
||||
# program_id(0) -> token_id (row)
|
||||
# program_id(1) -> tile index along columns
|
||||
token_id = tl.program_id(0)
|
||||
tile_id = tl.program_id(1)
|
||||
|
||||
# Each program covers BLOCK_N consecutive columns
|
||||
indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
# Load request id for this token (no mask: grid is exact)
|
||||
req = tl.load(req_id_ptr + token_id)
|
||||
|
||||
# Load token indices for this tile
|
||||
ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
|
||||
tok = tl.load(ti_ptr) # int32
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
|
||||
# Guard block_table access
|
||||
valid_block = block_id < max_num_blocks_per_req
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
base = tl.load(bt_ptr, mask=valid_block, other=0)
|
||||
|
||||
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
|
||||
out_val = tl.where(is_invalid_tok | (~valid_block), -1,
|
||||
base * BLOCK_SIZE + inblock_off)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
|
||||
req_id: torch.Tensor, # int32 [num_tokens]
|
||||
block_table: torch.
|
||||
Tensor, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
):
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
block_table[req_id[token_id],
|
||||
token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
|
||||
+ token_indices[token_id, indice_id] % BLOCK_SIZE
|
||||
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, \
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \
|
||||
f"BLOCK_N ({BLOCK_N})"
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
num_requests, max_num_blocks_per_req = block_table.shape
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
req_id_c = req_id.contiguous()
|
||||
block_table_c = block_table.contiguous()
|
||||
token_indices_c = token_indices.contiguous()
|
||||
out = torch.empty_like(token_indices_c)
|
||||
|
||||
# Strides in elements
|
||||
bt_stride0, bt_stride1 = block_table_c.stride()
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
_convert_req_index_to_global_index_kernel[grid](
|
||||
req_id_c,
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
ti_stride0,
|
||||
ti_stride1,
|
||||
out_stride0,
|
||||
out_stride1,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadataBuilder(
|
||||
AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
sm_count = props.multi_processor_count
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(
|
||||
parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
|
||||
self.topk_tokens_tensor = torch.tensor([self.topk_tokens],
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
self.max_model_len_tensor = torch.tensor(
|
||||
[self.model_config.max_model_len],
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices not None
|
||||
self.dummy_block_table = torch.empty((1, 1),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
# Equation taken from FlashMLA/csrc/pybind.cpp
|
||||
h_q, h_k = self.num_heads, 1
|
||||
s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest
|
||||
max_num_sm_parts = int(
|
||||
max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1))
|
||||
if current_platform.is_device_capability(100):
|
||||
max_num_sm_parts *= 2
|
||||
self.tile_scheduler_metadata_buffer = torch.empty(
|
||||
# TileSchedulerMetaDataSize = 8
|
||||
# see: FlashMLA/csrc/params.h
|
||||
(max_num_sm_parts, 8),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.num_splits_buffer = torch.empty(
|
||||
# We pack all the tokens into one batch for sparse attention.
|
||||
# Otherwise, we can exceed the sm of `get_mla_metadata`.
|
||||
(
|
||||
2, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.req_id_per_token_buffer = torch.empty(
|
||||
(vllm_config.scheduler_config.max_num_batched_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> FlashMLASparseMetadata:
|
||||
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
|
||||
dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
req_id_per_token = np.repeat(
|
||||
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
|
||||
# Zero-fill for cudagraphs
|
||||
self.req_id_per_token_buffer.fill_(0)
|
||||
self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\
|
||||
.copy_(torch.from_numpy(req_id_per_token), non_blocking=True)
|
||||
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
|
||||
|
||||
fp8_extra_metadata = None
|
||||
if self.use_fp8_kv_cache:
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens=self.topk_tokens_tensor,
|
||||
num_q_tokens_per_head_k=num_tokens * self.num_heads,
|
||||
topk=self.topk_tokens,
|
||||
num_heads_q=self.num_heads,
|
||||
num_heads_k=1,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
num_sm_parts = tile_scheduler_metadata.size(0)
|
||||
# Copy to persistent buffer for full-CG support
|
||||
tile_scheduler_metadata_buffer = \
|
||||
self.tile_scheduler_metadata_buffer[:num_sm_parts]
|
||||
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
|
||||
self.num_splits_buffer.copy_(num_splits)
|
||||
|
||||
fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
|
||||
scheduler_metadata=tile_scheduler_metadata_buffer,
|
||||
num_splits=self.num_splits_buffer,
|
||||
# cache_lens and block_table are basically unused in sparse case
|
||||
# but the decode kernel will treat -1 and indices >= cache_lens
|
||||
# as invalid so we make sure cache_lens is large enough to not
|
||||
# accidentally mark indices invalid, we will use -1 exclusively
|
||||
# to mark invalid indices
|
||||
cache_lens=self.max_model_len_tensor,
|
||||
dummy_block_table=self.dummy_block_table)
|
||||
|
||||
metadata = FlashMLASparseMetadata(
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
block_table=common_attn_metadata.block_table_tensor,
|
||||
req_id_per_token=req_id_per_token,
|
||||
block_size=self.kv_cache_spec.block_size,
|
||||
topk_tokens=self.topk_tokens,
|
||||
fp8_extra_metadata=fp8_extra_metadata,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[list[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
topk_indice_buffer: Optional[torch.Tensor] = None,
|
||||
indexer: Optional["Indexer"] = None,
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
self.softmax_scale = scale
|
||||
assert indexer is not None
|
||||
self.topk_indices_buffer = indexer.topk_indices_buffer
|
||||
self.padding = 128 if current_platform.is_device_capability(
|
||||
100) else 64
|
||||
|
||||
def _forward_bf16_kv(
|
||||
self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
|
||||
num_tokens = q.shape[0]
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
|
||||
-1, 1, kv_c_and_k_pe_cache.shape[-1])
|
||||
|
||||
# NOTE(Chen): kernel requires num_local_head to be a multiple of
|
||||
# 64 on hopper and 128 on blackwell
|
||||
if self.num_heads % self.padding != 0:
|
||||
assert self.padding % self.num_heads == 0
|
||||
logger.warning_once(f"padding num_heads to {self.padding} \
|
||||
due to sparse attn kernel requirement")
|
||||
q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
|
||||
q_padded[:, :self.num_heads, :] = q
|
||||
q = q_padded
|
||||
|
||||
topk_indices = topk_indices.view(num_tokens, 1, -1)
|
||||
output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices,
|
||||
self.softmax_scale)[0]
|
||||
output = output[:, :self.num_heads, :]
|
||||
return output
|
||||
|
||||
def _forward_fp8_kv(self, q: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
topk_indices: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata) -> torch.Tensor:
|
||||
|
||||
assert attn_metadata.fp8_extra_metadata is not None
|
||||
extra_metadata = attn_metadata.fp8_extra_metadata
|
||||
|
||||
_attn_out, _ = flash_mla_with_kvcache(
|
||||
q=q.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
|
||||
block_table=extra_metadata.dummy_block_table,
|
||||
head_dim_v=512,
|
||||
cache_seqlens=extra_metadata.cache_lens,
|
||||
tile_scheduler_metadata=extra_metadata.scheduler_metadata,
|
||||
num_splits=extra_metadata.num_splits,
|
||||
is_fp8_kvcache=True,
|
||||
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
|
||||
return _attn_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashMLASparseMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
output_block_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
|
||||
# MQA 576/512 approach for both prefill and decode
|
||||
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for MLACommonImpl")
|
||||
|
||||
if attn_metadata is None:
|
||||
# The zero fill is required when used with DP + EP
|
||||
# to ensure all ranks within a DP group compute the
|
||||
# same expert outputs.
|
||||
return output.fill_(0)
|
||||
|
||||
num_actual_toks = attn_metadata.num_actual_tokens
|
||||
|
||||
# Inputs and outputs may be padded for CUDA graphs
|
||||
|
||||
q = q[:num_actual_toks, ...]
|
||||
k_c_normed = k_c_normed[:num_actual_toks, ...]
|
||||
k_pe = k_pe[:num_actual_toks, ...]
|
||||
|
||||
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
|
||||
dim=-1)
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
q_nope = q_nope.transpose(0, 1)
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
ql_nope = torch.bmm(q_nope, self.W_UK_T)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
ql_nope = ql_nope.transpose(0, 1)
|
||||
|
||||
topk_indices = self.topk_indices_buffer[:num_actual_toks]
|
||||
|
||||
# TODO: handle index / kv_cache correctly
|
||||
topk_indices_global = triton_convert_req_index_to_global_index(
|
||||
attn_metadata.req_id_per_token,
|
||||
attn_metadata.block_table,
|
||||
topk_indices,
|
||||
BLOCK_SIZE=attn_metadata.block_size,
|
||||
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
|
||||
)
|
||||
|
||||
q = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype != "fp8_ds_mla":
|
||||
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global,
|
||||
attn_metadata)
|
||||
else:
|
||||
attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global,
|
||||
attn_metadata)
|
||||
|
||||
self._v_up_proj(attn_out, out=output[:num_actual_toks])
|
||||
return output
|
342
vllm/v1/attention/backends/mla/indexer.py
Normal file
342
vllm/v1/attention/backends/mla/indexer.py
Normal file
@ -0,0 +1,342 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> type["AttentionMetadata"]:
|
||||
return DeepseekV32IndexerMetadata
|
||||
|
||||
@classmethod
|
||||
def get_supported_head_sizes(cls) -> list[int]:
|
||||
return [32, 64, 128]
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
|
||||
return DeepseekV32IndexerMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
assert num_kv_heads == 1
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_stride_order() -> tuple[int, ...]:
|
||||
return (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
block_table: torch.Tensor
|
||||
cu_seqlen_ks: torch.Tensor
|
||||
cu_seqlen_ke: torch.Tensor
|
||||
cu_seq_lens: torch.Tensor
|
||||
total_seq_lens: int
|
||||
token_start: int
|
||||
token_end: int
|
||||
num_reqs: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillMetadata:
|
||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
decode_lens: torch.Tensor
|
||||
requires_padding: bool
|
||||
schedule_metadata: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerMetadata:
|
||||
|
||||
# FIXME (zyongye)
|
||||
# hacky way to access the data now, need to be in chunked meta
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
# The dimension of the attention heads
|
||||
head_dim: int
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
|
||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||
|
||||
|
||||
# TODO (zyongye) optimize this, this is now vibe coded
|
||||
def kv_spans_from_batches(
|
||||
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
|
||||
device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
start_seq_loc: 1D long tensor [B+1], cumulative counts of
|
||||
selected tokens per batch.
|
||||
Example: [0, 2, 4, 7] ->
|
||||
batch sizes (selected) [2, 2, 3], N=7 tokens total.
|
||||
seq_len_per_batch: 1D long tensor [B],
|
||||
full sequence length (KV length) of each batch.
|
||||
Example: [5, 9, 4].
|
||||
|
||||
Returns:
|
||||
start_tensor: 1D long tensor [N], start offset in the
|
||||
concatenated KV cache for each token's batch.
|
||||
end_location: 1D long tensor [N],
|
||||
**exclusive** end = start + token's local position.
|
||||
(So the attended KV slice is kv[start:end].)
|
||||
|
||||
Assumes each batch contributes its full `seq_len_per_batch[i]`
|
||||
keys to the KV cache, andthe selected tokens within a batch
|
||||
are the **last** `counts[i]` positions of that sequence.
|
||||
"""
|
||||
q = start_seq_loc.to(dtype=torch.long)
|
||||
L = seq_len_per_batch.to(dtype=torch.long)
|
||||
assert q.dim() == 1 and L.dim() == 1
|
||||
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
|
||||
|
||||
# Selected tokens per batch and totals
|
||||
counts = q[1:] - q[:-1] # [B]
|
||||
N = int(q[-1].item()) # total selected tokens
|
||||
B = L.numel()
|
||||
|
||||
if N == 0:
|
||||
return (torch.empty(0, dtype=torch.long, device=device),
|
||||
torch.empty(0, dtype=torch.long, device=device))
|
||||
|
||||
# KV start offsets per batch in the concatenated KV cache
|
||||
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
|
||||
|
||||
# For each selected token, which batch does it belong to?
|
||||
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
|
||||
|
||||
# Map batch KV start to each token
|
||||
start_tensor = kv_starts_per_batch[batch_id] # [N]
|
||||
|
||||
# End-align local positions inside each batch:
|
||||
# local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b
|
||||
L_expand = torch.repeat_interleave(L, counts) # [N]
|
||||
m_expand = torch.repeat_interleave(counts, counts) # [N]
|
||||
# position within the selected block: 1..counts[b]
|
||||
pos_within = (torch.arange(N, dtype=torch.long) -
|
||||
torch.repeat_interleave(q[:-1], counts) + 1)
|
||||
|
||||
local_pos = L_expand - m_expand + pos_within # [N], 1-based
|
||||
end_location = start_tensor + local_pos # exclusive end
|
||||
|
||||
return start_tensor.int().to(device), end_location.int().to(device)
|
||||
|
||||
|
||||
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
|
||||
# May be tuned later.
|
||||
return max_model_len * 2
|
||||
|
||||
|
||||
def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
|
||||
max_prefill_buffer_size: int,
|
||||
reqs_start: int) -> list[tuple[int, int]]:
|
||||
"""
|
||||
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
|
||||
such that the total sequence length of each chunk is less than the
|
||||
maximum prefill buffer size.
|
||||
|
||||
Args:
|
||||
seq_lens_cpu: The sequence lengths of the prefill requests.
|
||||
max_prefill_buffer_size: The maximum prefill buffer size.
|
||||
reqs_start: The start index of the prefill requests.
|
||||
|
||||
Returns:
|
||||
A list of tuples of (reqs_start, reqs_end).
|
||||
"""
|
||||
chunk_seq_ids = []
|
||||
total_seq_lens = 0
|
||||
for i in range(reqs_start, len(seq_lens_cpu)):
|
||||
cur_seq_len = seq_lens_cpu[i].item()
|
||||
assert cur_seq_len <= max_prefill_buffer_size
|
||||
total_seq_lens += cur_seq_len
|
||||
if total_seq_lens > max_prefill_buffer_size:
|
||||
chunk_seq_ids.append((reqs_start, i))
|
||||
reqs_start = i
|
||||
total_seq_lens = cur_seq_len
|
||||
if total_seq_lens > 0:
|
||||
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
|
||||
return chunk_seq_ids
|
||||
|
||||
|
||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
scheduler_config = self.vllm_config.scheduler_config
|
||||
#NOTE(Chen):an estimated max size of flattened_kv. Need to double check.
|
||||
self.max_prefill_buffer_size = get_max_prefill_buffer_size(
|
||||
self.vllm_config)
|
||||
self.num_speculative_tokens = (
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config else 0)
|
||||
# Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
|
||||
self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
|
||||
|
||||
props = torch.cuda.get_device_properties(self.device)
|
||||
sm_count = props.multi_processor_count
|
||||
self.num_sms = sm_count
|
||||
|
||||
self.decode_lens_buffer = torch.empty(
|
||||
(scheduler_config.max_num_seqs, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
# See: DeepGMM/csrc/apis/attention.hpp
|
||||
self.scheduler_metadata_buffer = torch.empty((self.num_sms + 1, 2),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
def build_one_prefill_chunk(self, reqs_start, reqs_end,
|
||||
query_start_loc_cpu, seq_lens_cpu,
|
||||
block_table):
|
||||
prefill_query_start_loc = query_start_loc_cpu[
|
||||
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
|
||||
self.device)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = torch.cat([
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
|
||||
]).to(torch.int32).to(self.device)
|
||||
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||
cu_seqlen_ks=cu_seqlen_ks,
|
||||
cu_seqlen_ke=cu_seqlen_ke,
|
||||
cu_seq_lens=cu_seq_lens,
|
||||
total_seq_lens=total_seq_lens,
|
||||
block_table=block_table[reqs_start:reqs_end],
|
||||
token_start=token_start,
|
||||
token_end=token_end,
|
||||
num_reqs=reqs_end - reqs_start,
|
||||
)
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
|
||||
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start, reqs_end, query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks, )
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes])
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max()
|
||||
> decode_lens_cpu.min()).item()
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
|
||||
seq_lens, self.kv_cache_spec.block_size, self.num_sms)
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.
|
||||
block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
@ -102,6 +102,7 @@ class PallasAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
padded_head_size = cdiv(
|
||||
head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
|
||||
|
@ -360,6 +360,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
@ -68,6 +68,7 @@ class TreeAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
@ -171,6 +171,7 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
@ -106,6 +106,7 @@ class XFormersAttentionBackend(AttentionBackend):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
kv_cache_spec: The kv cache spec of each attention layer in the model
|
||||
"""
|
||||
|
||||
if is_kv_cache_spec_uniform(kv_cache_spec):
|
||||
if is_kv_cache_spec_uniform(
|
||||
kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type(
|
||||
kv_cache_spec):
|
||||
return
|
||||
|
||||
logger.warning(
|
||||
@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
num_kv_heads=spec.num_kv_heads,
|
||||
head_size=spec.head_size,
|
||||
dtype=spec.dtype,
|
||||
use_mla=spec.use_mla,
|
||||
sliding_window=spec.sliding_window,
|
||||
)
|
||||
elif isinstance(spec, ChunkedLocalAttentionSpec):
|
||||
@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
|
||||
num_kv_heads=spec.num_kv_heads,
|
||||
head_size=spec.head_size,
|
||||
dtype=spec.dtype,
|
||||
use_mla=spec.use_mla,
|
||||
attention_chunk_size=spec.attention_chunk_size,
|
||||
)
|
||||
|
||||
if not is_kv_cache_spec_uniform(kv_cache_spec):
|
||||
if not (is_kv_cache_spec_uniform(kv_cache_spec)
|
||||
or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)):
|
||||
raise ValueError("Hybrid KV cache manager is disabled but failed to "
|
||||
"convert the KV cache specs to one unified type.")
|
||||
|
||||
|
@ -10,7 +10,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec, FullAttentionSpec,
|
||||
KVCacheSpec, MambaSpec,
|
||||
SlidingWindowSpec)
|
||||
MLAAttentionSpec, SlidingWindowSpec)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@ -656,6 +656,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
MLAAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
|
||||
MambaSpec: MambaManager,
|
||||
|
@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
use_mla: bool
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
# For MLA we only store a single latent vector
|
||||
coef = 1 if self.use_mla else 2
|
||||
return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
return 2 * self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype)
|
||||
|
||||
|
||||
@ -118,12 +115,13 @@ class FullAttentionSpec(AttentionSpec):
|
||||
if spec.sliding_window is not None)
|
||||
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
|
||||
if spec.attention_chunk_size is not None)
|
||||
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge")
|
||||
merged_spec = cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
use_mla=specs[0].use_mla,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
@ -140,6 +138,38 @@ class FullAttentionSpec(AttentionSpec):
|
||||
return merged_spec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MLAAttentionSpec(FullAttentionSpec):
|
||||
# TODO(Lucas/Chen): less hacky way to do this
|
||||
cache_dtype_str: Optional[str] = None
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
if self.cache_dtype_str == "fp8_ds_mla":
|
||||
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
|
||||
# for details.
|
||||
return self.block_size * 656
|
||||
return self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be "
|
||||
"MLAAttentionSpec.")
|
||||
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
|
||||
assert len(cache_dtype_str_set) == 1, (
|
||||
"All attention layers in the same KV cache group must use the same "
|
||||
"quantization method.")
|
||||
return cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||
attention_chunk_size: int
|
||||
@ -163,9 +193,6 @@ class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||
class SlidingWindowSpec(AttentionSpec):
|
||||
sliding_window: int
|
||||
|
||||
def __post_init__(self):
|
||||
assert not self.use_mla, "MLA is not supported for sliding window"
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
assert vllm_config.parallel_config.decode_context_parallel_size == 1, \
|
||||
"DCP not support sliding window."
|
||||
@ -266,9 +293,13 @@ class UniformTypeKVCacheSpecs(KVCacheSpec):
|
||||
# Different block sizes, not uniform.
|
||||
return False
|
||||
one_spec = next(iter(kv_cache_specs.values()))
|
||||
if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)):
|
||||
if isinstance(one_spec, FullAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, type(one_spec))
|
||||
isinstance(spec, FullAttentionSpec)
|
||||
for spec in kv_cache_specs.values())
|
||||
elif isinstance(one_spec, CrossAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, CrossAttentionSpec)
|
||||
for spec in kv_cache_specs.values())
|
||||
elif isinstance(one_spec, SlidingWindowSpec):
|
||||
return all(
|
||||
|
@ -17,6 +17,7 @@ from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_multimodal
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_pin_memory_available
|
||||
@ -31,6 +32,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -47,10 +49,12 @@ class EagleProposer:
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
assert self.speculative_config is not None
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
|
||||
self.runner = runner
|
||||
self.device = device
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
@ -68,10 +72,15 @@ class EagleProposer:
|
||||
.is_multimodal_model
|
||||
|
||||
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||
self.draft_indexer_metadata_builder: Optional[
|
||||
AttentionMetadataBuilder] = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
|
||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE and
|
||||
not self.vllm_config.model_config.enforce_eager)
|
||||
not self.vllm_config.model_config.enforce_eager
|
||||
and not self.speculative_config.enforce_eager)
|
||||
self.cudagraph_batch_sizes = list(
|
||||
reversed(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
@ -178,20 +187,30 @@ class EagleProposer:
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# Select the correct attention metadata builders for EAGLE layers.
|
||||
# Get the attention metadata builders once and reuse for later.
|
||||
builder = (self._get_attention_metadata_builder()
|
||||
if self.attn_metadata_builder is None else
|
||||
self.attn_metadata_builder)
|
||||
attn_metadata = builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0)
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
attn_metadata_builder = \
|
||||
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0)
|
||||
# FIXME: support hybrid kv for draft model (remove separate indexer)
|
||||
if self.draft_indexer_metadata_builder:
|
||||
draft_indexer_metadata = (
|
||||
self.draft_indexer_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
))
|
||||
else:
|
||||
draft_indexer_metadata = None
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
for layer_name in self.indexer_layer_names:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
|
||||
if self.use_cuda_graph and \
|
||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
@ -222,8 +241,7 @@ class EagleProposer:
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp",
|
||||
"longcat_flash_mtp"):
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = last_hidden_states
|
||||
else:
|
||||
@ -323,7 +341,7 @@ class EagleProposer:
|
||||
exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
|
||||
# Rebuild attention metadata
|
||||
attn_metadata = builder.build_for_drafting( # type: ignore
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=token_index + 1)
|
||||
for layer_name in self.attn_layer_names:
|
||||
@ -352,8 +370,7 @@ class EagleProposer:
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method in ("deepseek_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp"):
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
@ -794,6 +811,10 @@ class EagleProposer:
|
||||
self.vllm_config.speculative_config.draft_model_config
|
||||
target_attn_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
|
||||
# FIXME: support hybrid kv for draft model
|
||||
target_indexer_layer_names = set(
|
||||
get_layers_from_vllm_config(self.vllm_config,
|
||||
DeepseekV32IndexerCache).keys())
|
||||
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
with set_model_tag("eagle_head"):
|
||||
@ -803,8 +824,25 @@ class EagleProposer:
|
||||
draft_attn_layer_names = (
|
||||
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
|
||||
target_attn_layer_names)
|
||||
|
||||
indexer_layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
DeepseekV32IndexerCache)
|
||||
draft_indexer_layer_names = (indexer_layers.keys() -
|
||||
target_indexer_layer_names)
|
||||
self.attn_layer_names = list(draft_attn_layer_names)
|
||||
self.indexer_layer_names = list(draft_indexer_layer_names)
|
||||
|
||||
if self.indexer_layer_names:
|
||||
first_layer = self.indexer_layer_names[0]
|
||||
self.draft_indexer_metadata_builder = (
|
||||
indexer_layers[first_layer].get_attn_backend().get_builder_cls(
|
||||
)(
|
||||
indexer_layers[first_layer].get_kv_cache_spec(),
|
||||
self.indexer_layer_names,
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
))
|
||||
else:
|
||||
self.draft_indexer_metadata_builder = None
|
||||
|
||||
if supports_multimodal(target_model):
|
||||
# handle multimodality
|
||||
@ -888,10 +926,10 @@ class EagleProposer:
|
||||
def _get_attention_metadata_builder(
|
||||
self) -> list[AttentionMetadataBuilder]:
|
||||
"""Find and return the attention metadata builders for EAGLE layers.
|
||||
|
||||
|
||||
Returns:
|
||||
The metadata builders for EAGLE layers.
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If no metadata builders are found for EAGLE layers.
|
||||
"""
|
||||
|
@ -40,6 +40,7 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
||||
@ -80,7 +81,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
MambaSpec, SlidingWindowSpec,
|
||||
MambaSpec, MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs)
|
||||
# yapf: enable
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||
@ -2989,13 +2991,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
if self.parallel_config.enable_dbo and allow_microbatching:
|
||||
ubatch_slices, num_tokens_after_padding = ubatch_split(
|
||||
ubatch_slices, ubatch_num_tokens_after_padding = ubatch_split(
|
||||
num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
total_num_scheduled_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
vllm_config=self.vllm_config,
|
||||
)
|
||||
# Currently when DBO is enabled `ubatch_split` returns
|
||||
# the num_tokens_after_padding for a single ubatch, but we have 2
|
||||
# TODO(sage,lucas): this is cruft that should be addressed in the
|
||||
# padding refactor.
|
||||
if ubatch_num_tokens_after_padding is not None:
|
||||
num_tokens_after_padding = ubatch_num_tokens_after_padding * 2
|
||||
|
||||
# If we failed to microbatch, currently need to resynchronize
|
||||
# TODO(lucas,sage): we should be able to avoid this second sync by
|
||||
@ -3062,7 +3070,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_metadata_i = (attn_group\
|
||||
.get_metadata_builder(ubatch_id=ubid)\
|
||||
.build_for_cudagraph_capture(common_attn_metadata))
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
for layer_name in attn_group.layer_names:
|
||||
assert type(attn_metadata) is list
|
||||
attn_metadata[ubid][
|
||||
layer_name] = attn_metadata_i
|
||||
@ -3070,7 +3078,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert type(attn_metadata) is dict
|
||||
attn_metadata_i = attn_group.get_metadata_builder()\
|
||||
.build_for_cudagraph_capture(common_attn_metadata)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
for layer_name in attn_group.layer_names:
|
||||
attn_metadata[layer_name] = attn_metadata_i
|
||||
|
||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||
@ -3112,8 +3120,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# filter out the valid batch descriptor
|
||||
_cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
|
||||
BatchDescriptor(num_tokens=num_tokens,
|
||||
uniform_decode=uniform_decode))
|
||||
BatchDescriptor(num_tokens=num_tokens_after_padding,
|
||||
uniform_decode=uniform_decode)) \
|
||||
if not is_profile else (CUDAGraphMode.NONE, None)
|
||||
if cudagraph_runtime_mode is not None:
|
||||
# we allow forcing NONE when the dispatcher disagrees to support
|
||||
# warm ups for cudagraph capture
|
||||
@ -3125,7 +3134,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cudagraph_runtime_mode = _cg_mode
|
||||
|
||||
if ubatch_slices is not None:
|
||||
num_tokens = num_tokens // 2
|
||||
# Adjust values to reflect a single ubatch.
|
||||
# TODO(sage,lucas): this is cruft that should be addressed in
|
||||
# the padding refactor.
|
||||
num_tokens_after_padding = ubatch_slices[0].num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[:] = num_tokens_after_padding
|
||||
|
||||
with self.maybe_randomize_inputs(input_ids), set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
@ -3810,8 +3825,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
num_blocks,
|
||||
kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size,
|
||||
cache_dtype_str=self.cache_config.cache_dtype)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = \
|
||||
@ -3997,7 +4015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
Add encoder-only layers to the KV cache config.
|
||||
"""
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
encoder_only_attn_specs: dict[AttentionSpec,
|
||||
list[str]] = defaultdict(list)
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
@ -4007,8 +4024,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
dtype=self.kv_cache_dtype)
|
||||
encoder_only_attn_specs[attn_spec].append(layer_name)
|
||||
self.runner_only_attn_layers.add(layer_name)
|
||||
if len(encoder_only_attn_specs) > 0:
|
||||
@ -4030,6 +4046,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
@ -4049,13 +4066,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if attn_module.sliding_window is not None:
|
||||
assert not use_mla, "MLA is not supported for sliding" \
|
||||
"window"
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla)
|
||||
sliding_window=attn_module.sliding_window)
|
||||
elif use_mla:
|
||||
kv_cache_spec[layer_name] = MLAAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
cache_dtype_str=cache_dtype_str)
|
||||
elif self.attention_chunk_size is not None \
|
||||
and isinstance(attn_module, ChunkedLocalAttention):
|
||||
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||
@ -4063,22 +4088,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
use_mla=use_mla)
|
||||
attention_chunk_size=self.attention_chunk_size)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
dtype=self.kv_cache_dtype)
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
dtype=self.kv_cache_dtype)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
@ -4115,6 +4137,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0),
|
||||
)
|
||||
ds_indexer_layers = get_layers_from_vllm_config(
|
||||
self.vllm_config, DeepseekV32IndexerCache)
|
||||
for layer_name, ds_indexer_module in ds_indexer_layers.items():
|
||||
kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
|
@ -530,7 +530,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=False,
|
||||
)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
@ -538,7 +537,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
|
Reference in New Issue
Block a user