Compare commits

...

25 Commits

Author SHA1 Message Date
e4beabd2c8 [BugFix] Fix default kv-cache-dtype default for DeepseekV3.2 (#25988)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:42 -07:00
febb688356 [Bugfix] Fix __syncwarp on ROCM (#25996)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:42 -07:00
a1825fe645 [MM] Add text-only mode for Qwen3-VL (#26000)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:42 -07:00
bab9231bf1 [Model] MTP fallback to eager for DeepSeek v32 (#25982)
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:38 -07:00
c214d699fd [spec decode] Consolidate speculative decode method name for MTP (#25232)
Signed-off-by: zixi-qi <qizixi@meta.com>
2025-09-30 22:47:11 -07:00
c3dfb0f6dd [Bench] Add DeepSeekV32 to MoE benchmark (#25962)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
83f3c9beae [bugfix][deepseek] fix flashmla kernel selection (#25956)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
d0b178cef1 [NIXL] Add support for MLA caches with different latent dim (#25902)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
b3230e1ac0 [New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Xiaozhu Meng <mxz297@gmail.com>
Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
03df0fb5d2 [BugFix] Fix DP/EP hang (#25906)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:10 -07:00
9471879bd4 [Bug] Fix Weight Loading for Block FP8 Cutlass SM90 (#25909)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:32:47 -07:00
ab5b6459df [Bugfix] Fallback ViT attn backend to SDPA for blackwell (#25851)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:32:47 -07:00
8ce5d3198d [P/D] NIXL Updates (#25844)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-28 22:55:33 -07:00
09c2cbc04a [Bugfix] fix Qwen3VLMoe load when pp > 1 (#25838)
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-28 22:55:17 -07:00
4c347044c9 [VLM] Update Qwen3-VL max_num_video_tokens calculation for configurable video profiling (#25557)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:12 -07:00
19e7ab7315 [Bugfix] Fix Qwen3-VL regression from #24982 (#25814)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
6de3d431d9 [MM] Optimize memory profiling for scattered multimodal embeddings (#25810)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
b14773bd64 [Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
26a7a33b88 [Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:03 -07:00
5aa5811a16 [CI] Fix FlashInfer AOT in release docker image (#25730)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
c2fa2d4dc9 [Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
32335c8b34 Add option to restrict media domains (#25783)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
04c2b26972 Add filtering for chat template kwargs (#25794)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
ee10d7e6ff Validate API tokens in constant time (#25781)
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
bb79c4da2f Reduce the Cuda Graph memory footprint when running with DBO (#25779)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
122 changed files with 5413 additions and 821 deletions

View File

@@ -76,7 +76,7 @@ steps:
queue: arm64_cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
# Add job to create multi-arch manifest

View File

@@ -584,8 +584,9 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
elif config.architectures[0] in (
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM",
):
E = config.n_routed_experts

View File

@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(SUPPORT_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
list(APPEND SUPPORT_ARCHS 9.0a)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
list(APPEND SUPPORT_ARCHS 10.0a)
endif()
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
if(FLASH_MLA_ARCHS)
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
)
set(FlashMLA_Extension_SOURCES
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc)
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
set(FlashMLA_Extension_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_Extension_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
define_gpu_extension_target(
_flashmla_C
DESTINATION vllm
@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
USE_SABI 3
WITH_SOABI)
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
target_compile_options(_flashmla_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
define_gpu_extension_target(
_flashmla_extension_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${FlashMLA_Extension_SOURCES}
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
USE_SABI 3
WITH_SOABI)
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
target_compile_options(_flashmla_extension_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
else()
# Create an empty target for setup.py when not targeting sm90a systems
# Create empty targets for setup.py when not targeting sm90a systems
add_custom_target(_flashmla_C)
add_custom_target(_flashmla_extension_C)
endif()

View File

@@ -56,3 +56,11 @@ void cp_gather_cache(
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);

View File

@@ -16,6 +16,7 @@
#include <algorithm>
#include <cassert>
#include <cfloat> // FLT_MIN
#include <map>
#include <vector>
@@ -396,6 +397,180 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_ds_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int64_t dst_idx_start =
block_idx * block_stride + block_offset * entry_stride;
// Create 4 tile scales in shared memory
__shared__ float smem[20];
float* shard_abs_max = smem;
float* tile_scales = smem + 16;
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// The first thread of the first warp in each tile writes the scale
// value for the tile. The RoPE part (last 64 elements) is handled
// by another 2 warps (64 threads).
// So in total, we use 18 warps (576 threads) per block.
// Cast kv_cache to 16_bit for RoPE values
scalar_t* kv_cache_16bit =
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
// The last 64 threads handle the RoPE part
if (threadIdx.x >= kv_lora_rank) {
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
// RoPE values start after the packed 8-bit NoPE values and the
// 32-bit scales
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
kv_cache_16bit[dst_idx] = k_pe[src_idx];
return;
}
// Determine the scale for each chunk of NoPE
const int16_t tile_idx = threadIdx.x >> 7;
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
const int16_t lane_idx = threadIdx.x & 31;
// Load the NoPE element for this thread into registers
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
const scalar_t src_val = kv_c[src_idx];
// Warp-level reduction to find the max absolute value in the warp
float max_abs = fabsf(src_val);
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
#ifdef USE_ROCM
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
#else
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
#endif
}
// The first lane of each warp in each tile writes the max_abs of this part
// of the tile to shared memory
if (lane_idx == 0) {
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
}
__syncthreads();
// The first lane of the first warp in each tile computes the scale for the
// tile and writes it to shared memory and to kv_cache
if (warp_idx == 0 && lane_idx == 0) {
float4 shard_abs_max_vec =
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
448.f;
// Avoid division by zero in `scaled_convert`
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
}
__syncthreads();
// Now all threads in the block scale and write their element
const float scale_val = tile_scales[tile_idx];
const int64_t dst_idx = dst_idx_start + threadIdx.x;
kv_cache[dst_idx] =
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
src_val, scale_val);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim, // dimension of each head
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
}
float2 k_val = (reinterpret_cast<const float2*>(
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
float amax = 0.0f;
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
}
#ifndef USE_ROCM
__syncwarp();
#endif
// Reduced amax
for (int mask = 16; mask > 0; mask /= 2) {
#ifdef USE_ROCM
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
}
#ifndef USE_ROCM
__syncwarp();
#endif
float scale = fmaxf(amax, 1e-4) / 448.0f;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
}
if (threadIdx.x == 0) {
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride +
cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
} // namespace vllm
// KV_T is the data type of key and value tensors.
@@ -438,7 +613,7 @@ void reshape_and_cache(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE)
CALL_RESHAPE_AND_CACHE);
}
// KV_T is the data type of key and value tensors.
@@ -509,6 +684,18 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
@@ -531,20 +718,44 @@ void concat_and_cache_mla(
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
if (kv_cache_dtype == "fp8_ds_mla") {
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
TORCH_CHECK(kv_c.itemsize() == 2,
"kv_c.itemsize() must be 2 for fp8_ds_mla");
TORCH_CHECK(k_pe.itemsize() == 2,
"k_pe.itemsize() must be 2 for fp8_ds_mla");
} else {
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
}
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
if (kv_cache_dtype == "fp8_ds_mla") {
dim3 grid(num_tokens);
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// The first thread of the first warp in each tile writes the scale
// value for the tile. The RoPE part (last 64 elements) is handled
// by another 2 warps (64 threads).
// So in total, we use 18 warps (576 threads) per block.
dim3 block(576);
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_DS_MLA);
} else {
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}
}
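In other words, the fp8_ds_mla layout checked and launched above works out to 512 × 1 B of fp8 NoPE values + 4 × 4 B fp32 tile scales + 64 × 2 B 16-bit RoPE values = 656 B per token, with 4 tiles × 128 threads + 64 RoPE threads = 576 threads per block.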
namespace vllm {
@@ -922,3 +1133,42 @@ void cp_gather_cache(
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
cache_block_size, cache_stride, use_ue8m0);
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt) {
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
CALL_INDEXER_K_QUANT_AND_CACHE);
}
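For illustration, a minimal sketch of calling the newly bound op from Python, assuming the compiled `_C` extension is importable (e.g. by importing vLLM's custom-ops module) and that each cache row packs head_dim fp8 bytes per token followed by one fp32 scale per quant block, as the kernel's index math implies:

import torch
import vllm._custom_ops  # noqa: F401  # assumption: importing this loads the _C extension

num_tokens, head_dim = 16, 128
num_blocks, block_size = 8, 64
quant_block_size = 128
# Per-token row: head_dim fp8 bytes plus (head_dim // quant_block_size) fp32 scales.
cache_stride = head_dim + (head_dim // quant_block_size) * 4

k = torch.randn(num_tokens, head_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, cache_stride,
                       dtype=torch.uint8, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device="cuda")

# The trailing string selects the scale format; "ue8m0" rounds each scale up
# to a power of two, matching the exp2f(ceilf(log2f(scale))) path above.
torch.ops._C_cache_ops.indexer_k_quant_and_cache(
    k, kv_cache, slot_mapping, quant_block_size, "ue8m0")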

View File

@@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else if (KV_DTYPE == "fp8_ds_mla") { \
if (SRC_DTYPE == at::ScalarType::Float) { \
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::Half) { \
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
} else { \
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \

View File

@@ -713,6 +713,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
"slot_mapping, "
"int quant_block_size, str kv_cache_dtype) -> ()");
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {

View File

@@ -404,6 +404,9 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. -f1)
# Build AOT kernels
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
python3 -m flashinfer.aot

View File

@@ -6,6 +6,10 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
!!! tip
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict the domains that vLLM is allowed to fetch media from, so the server cannot be coerced into requesting arbitrary endpoints (a Server-Side Request Forgery, SSRF, risk). You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
## Offline Inference
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:

View File

@@ -60,6 +60,12 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components
### 4. **Restrict Domains Access for Media URLs:**
Restrict the domains that vLLM is allowed to fetch media URLs from by setting
`--allowed-media-domains`, which helps prevent Server-Side Request Forgery (SSRF) attacks
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`).
## Security and Firewalls: Protecting Exposed vLLM Systems
While vLLM is designed to allow unsafe network services to be isolated to

View File

@@ -54,6 +54,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
choices=["ngram", "eagle", "eagle3", "mtp"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method.endswith("mtp"):
elif args.method == "mtp":
speculative_config = {
"method": args.method,
"method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
else:
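For reference, a minimal sketch of an engine setup using the consolidated "mtp" method, assuming the example forwards this dict to LLM(speculative_config=...) and using an illustrative model name (MTP needs a checkpoint that ships MTP draft weights):

from vllm import LLM, SamplingParams

speculative_config = {
    "method": "mtp",
    "num_speculative_tokens": 2,
}
llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp",
          speculative_config=speculative_config)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))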

View File

@@ -322,6 +322,8 @@ class precompiled_wheel_utils:
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/_flashmla_extension_C.abi3.so",
"vllm/_sparse_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/cumem_allocator.abi3.so",
@@ -589,6 +591,8 @@ if _is_cuda():
# not targeting a hopper system
ext_modules.append(
CMakeExtension(name="vllm._flashmla_C", optional=True))
ext_modules.append(
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops():

View File

@@ -191,7 +191,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
use_mla=False,
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,

View File

@@ -45,6 +45,7 @@ class MockModelConfig:
logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False

View File

@@ -240,6 +240,7 @@ class MockModelConfig:
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

View File

@@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
# yapf: disable

View File

@@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_cache_slice = ref_cache[block_idx, block_offset]
ref_cache_16bit = ref_cache_slice.view(dtype)
ref_cache_32bit = ref_cache_slice.view(torch.float32)
kv_c_data = kv_c[i]
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = (tile_idx + 1) * 128
tile_data[:] = kv_c_data[tile_start:tile_end]
# tile_scale = tile_data.amax().to(torch.float32) / 448.
# NOTE: Using torch's amax() gives different results,
# so this must be manually computed.
tile_data_float = tile_data.to(torch.float32)
manual_max = abs(tile_data_float[0])
for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448.
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8")
for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache_slice = kv_cache[block_idx, block_offset]
ref_cache_slice = ref_cache[block_idx, block_offset]
kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)

View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
fp8_paged_mqa_logits, get_num_sms,
get_paged_mqa_logits_metadata)
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
# x: (num_blocks, block_size, 1, head_dim)
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
x_fp8 = torch.empty(
(num_blocks, block_size * (head_dim + 4)),
device=x.device,
dtype=torch.uint8,
)
x_fp8[:, :block_size * head_dim] = x_scaled.view(
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
x_fp8[:,
block_size * head_dim:] = sf.view(num_blocks,
block_size).view(dtype=torch.uint8)
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
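# Note: with head_dim = 128 and block_size = 64 as used below, each packed block
# row is 64 * (128 + 4) = 8448 bytes: 8192 fp8 value bytes followed by 64 fp32
# per-token scales.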
def per_custom_dims_cast_to_fp8(
x: torch.Tensor, dims: tuple,
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
chunk_size = seq_len // 2
cp_size = seq_len_kv // seq_len
cp_id = cp_size // 3
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
for i in range(chunk_size):
ke[i] = cp_id * chunk_size + i
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
return ks, ke
def _ref_fp8_mqa_logits(
q: torch.Tensor,
kv: torch.Tensor,
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
):
seq_len_kv = kv.shape[0]
k = kv
q = q.float()
k = k.float()
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
< cu_seqlen_ke[:, None])
mask = mask_lo & mask_hi
score = torch.einsum("mhd,and->hmn", q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))
return logits
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
def test_deepgemm_fp8_mqa_logits():
torch.manual_seed(0)
random.seed(0)
num_heads, head_dim = 32, 128
for seq_len in (512, ):
for seq_len_kv in (1024, ):
for disable_cp in (False, True):
q = torch.randn(
seq_len,
num_heads,
head_dim,
device="cuda",
dtype=torch.bfloat16,
)
kv = torch.randn(seq_len_kv,
head_dim,
device="cuda",
dtype=torch.bfloat16)
weights = torch.randn(seq_len,
num_heads,
device="cuda",
dtype=torch.float32)
if disable_cp:
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.arange(seq_len, dtype=torch.int,
device="cuda") + (seq_len_kv - seq_len)
else:
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
ref_logits = _ref_fp8_mqa_logits(
q=q,
kv=kv,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke,
)
ref_neginf_mask = ref_logits == float("-inf")
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
def _ref_fp8_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
):
batch_size, next_n, _, _ = q.size()
_, block_size, _, _ = kv_cache.size()
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device=q.device,
dtype=torch.float32,
)
context_lens_list = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens_list[i]
q_offsets = torch.arange(context_len - next_n,
context_len,
device="cuda")
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
0, 1).contiguous())
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
qx, kx = q[i], kv_cache[block_idx]
k_offsets = torch.arange(
block_rk * block_size,
(block_rk + 1) * block_size,
device="cuda",
)
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
<= q_offsets[:, None])
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n:(i + 1) * next_n,
block_rk * block_size:(block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
float("-inf"))
return logits
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
def test_deepgemm_fp8_paged_mqa_logits():
torch.manual_seed(0)
random.seed(0)
max_model_len = 4096
for batch_size, next_n in [(4, 1), (2, 2)]:
for heads, index_dim in [(32, 128)]:
for avg_kv in (2048, ):
num_blocks, blocksize = max_model_len * 2, 64
q = torch.randn(
(batch_size, next_n, heads, index_dim),
device="cuda",
dtype=torch.bfloat16,
)
kv_cache = torch.randn(
(num_blocks, blocksize, 1, index_dim),
device="cuda",
dtype=torch.bfloat16,
)
weights = torch.randn(
(batch_size * next_n, heads),
device="cuda",
dtype=torch.float32,
)
context_lens = (torch.randint(int(0.8 * avg_kv),
int(1.2 * avg_kv),
(batch_size, )).cuda().to(
torch.int32))
max_block_len = ((context_lens.max().item() + blocksize - 1) //
blocksize * blocksize)
block_tables = torch.zeros(
(batch_size, max_block_len),
device="cuda",
dtype=torch.int32,
)
counter = 0
block_idx_pool = list(range(num_blocks))
random.shuffle(block_idx_pool)
for i in range(batch_size):
ctx_len = int(context_lens[i].item())
for j in range((ctx_len + blocksize - 1) // blocksize):
block_tables[i][j] = block_idx_pool[counter]
counter += 1
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms())
logits = fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
context_lens,
block_tables,
schedule_metadata,
max_model_len,
)
ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
context_lens,
block_tables,
max_model_len,
)
positions = (torch.arange(max_model_len,
device="cuda").unsqueeze(0).expand(
batch_size * next_n, -1))
row_indices = (
torch.arange(batch_size * next_n, device="cuda") // next_n)
next_n_offset = (
torch.arange(batch_size * next_n, device="cuda") % next_n)
mask = positions <= (context_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
logits = logits.masked_fill(~mask, 0)
ref_logits = ref_logits.masked_fill(~mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"

View File

@@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
descale_k = None
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
return flash_mla_with_kvcache(q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k)
def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()

View File

@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
def _cuda_sm90_available() -> bool:
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major == 9
def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 128
num_heads_k = 1
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
topk = 128
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
assert tile_md.dtype == torch.int32
assert num_splits.dtype == torch.int32
def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 1
head_dim_k = 576
head_dim_v = 512
num_heads_k = 1
page_block_size = 64
bytes_per_token = 656
topk = 128
# Metadata
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
# q_heads_per_hk = num_heads_q // num_heads_k
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
# Inputs
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
dtype=torch.bfloat16,
device=device)
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
dtype=torch.uint8,
device=device)
indices = torch.zeros((batch_size, seqlen_q, topk),
dtype=torch.int32,
device=device)
block_table = torch.zeros((batch_size, 128),
dtype=torch.int32,
device=device)
out, lse = fm.flash_mla_with_kvcache(q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_md,
num_splits,
indices=indices,
is_fp8_kvcache=True)
assert out.shape[0] == batch_size
assert out.shape[-1] == head_dim_v
assert lse.shape[0] == batch_size
def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
s_q = 1
s_kv = 1
h_q = 64 # kernel expects multiple of 64
h_kv = 1
d_qk = 576
d_v = 512
topk = 128
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
d_v)
assert out.shape == (s_q, h_q, d_v)
assert max_logits.shape == (s_q, h_q)
assert lse.shape == (s_q, h_q)

View File

@@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.testing import assert_close
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
def test_pack_seq_basic_fp8():
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test cases with 3D tensors (N, H, D)
test_cases = [
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
]
for N, H, D, B, lengths_list in test_cases:
# Create input tensor with small values for fp8
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor(lengths_list, device=device)
# Pack the data
packed = pack_seq_triton(x, lengths)
# Check output shape and properties
expected_shape = (B, max(lengths_list), H, D)
assert packed.shape == expected_shape
assert packed.dtype == dtype
assert packed.device == x.device
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = sum(lengths_list[:b])
seq_len = lengths_list[b]
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
actual_data = packed[b, :seq_len].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
def test_pack_seq_custom_padding_fp8():
"""Test pack_seq_triton with custom padding values for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 20, 8, 16, 2
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
# Test with different padding values
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
result = pack_seq_triton(x, lengths, pad_value=pad_value)
# Check valid data
for b in range(B):
start_idx = b * 10
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
actual_data = result[b, :10].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
# Check padding (fp8 has limited range, so check for large values)
padded_data = result[:, 10:].to(torch.float32)
if pad_value < 0:
assert torch.all(padded_data < -50) # Large negative values
elif pad_value > 0:
assert torch.all(padded_data > 50) # Large positive values
else:
assert torch.allclose(padded_data,
torch.zeros_like(padded_data),
atol=1e-2)
def test_pack_seq_default_negative_inf_padding_fp8():
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# B = 2
N, H, D = 20, 8, 16
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
result = pack_seq_triton(x, lengths)
# Check that padding is large negative values (fp8 representation of -inf)
padded_data = result[:, 10:].to(torch.float32)
assert torch.all(
padded_data < -100) # fp8 -inf is represented as large negative number
def test_pack_seq_edge_cases_fp8():
"""Test pack_seq_triton with edge cases for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test with single batch element
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([10], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (1, 10, 8, 16)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([1, 1, 1], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (3, 1, 4, 8)
# Test with different sequence lengths
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([5, 7, 3], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (3, 7, 8, 16)
def test_pack_seq_different_block_sizes_fp8():
"""Test pack_seq_triton with different block sizes for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 100, 16, 32, 4
lengths = torch.tensor([25, 25, 25, 25], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
# Test different block sizes
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
assert result.shape == (B, 25, H, D)
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = b * 25
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
actual_data = result[b, :25].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
def test_pack_seq_shape_consistency():
"""Test that pack_seq_triton maintains shape consistency."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 20, 8, 16, 2
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
result = pack_seq_triton(x, lengths)
# Check shape consistency
assert result.shape[0] == B # Batch dimension
assert result.shape[1] == lengths.max().item() # Max sequence length
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
def test_pack_unpack_roundtrip_fp8():
"""Test that pack -> unpack gives us back the original data for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test cases with 3D tensors
test_cases = [
(6, 8, 4, 2, [3, 3]),
(10, 4, 8, 3, [2, 4, 4]),
(20, 16, 32, 4, [5, 5, 5, 5]),
(15, 8, 16, 3, [7, 5, 3]),
]
for N, H, D, B, lengths_list in test_cases:
# Create input tensor with small values for fp8
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor(lengths_list, device=device)
# Pack the data
packed = pack_seq_triton(x, lengths)
# Unpack the data
unpacked = unpack_seq_triton(packed, lengths)
# Check that we get back the original data (within fp8 precision)
assert unpacked.shape == x.shape
x_f32 = x.to(torch.float32)
unpacked_f32 = unpacked.to(torch.float32)
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
# Unpack without explicit start locations (computed in kernel)
unpacked_with_loc = unpack_seq_triton(packed, lengths)
assert_close(x_f32,
unpacked_with_loc.to(torch.float32),
rtol=1e-3,
atol=1e-2)
def test_unpack_seq_triton_edge_cases_fp8():
"""Test unpack function with edge cases for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test with single batch element
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([10], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([1, 1, 1], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
# Only compare the first 3 elements that were actually packed
assert_close(x[:3].to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([5, 7, 3], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)

View File

@@ -207,6 +207,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
min_transformers_version="4.54"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",

View File

@@ -8,7 +8,8 @@ import pytest
from vllm import LLM
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
get_kv_cache_configs)
from vllm.v1.engine.core import EngineCore as V1EngineCore
from ..utils import create_new_process_for_each_test
@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# Avoid calling model.forward()
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_configs(
kv_cache_configs = get_kv_cache_configs(
vllm_config,
kv_cache_specs,
[10 * GiB_bytes],
)[0]
)
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
kv_cache_configs)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config

View File

@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str):
connector = MediaConnector()
connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
url_image = url_images[raw_image_url]
try:
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
modality_idxs = argsort_mm_positions(mm_positions)
assert modality_idxs == expected_modality_idxs
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
with pytest.raises(ValueError):
_, _ = connector.fetch_video(disallowed_url)
with pytest.raises(ValueError):
_, _ = await connector.fetch_video_async(disallowed_url)

View File

@ -26,5 +26,5 @@ class DummyPlatform(Platform):
def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink):
has_sink, use_sparse):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
from typing import Optional, Union
import pytest
import torch
@ -10,6 +11,7 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache(
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
randomize_blocks: bool = True,
kv_cache_dtype: Optional[str] = None,
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
"""Create and prepopulate an MLA KV cache with context data.
Args:
@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
kv_cache_dtype: Optional kv cache dtype string. When set to
"fp8_ds_mla" the cache is populated using the
fp8 DeepSeek MLA layout via concat_and_cache_mla.
scale: Scaling factor forwarded to concat_and_cache_mla when the
fp8 cache layout is requested.
Returns:
MLA KV cache tensor
@ -105,23 +114,61 @@ def create_and_prepopulate_kv_cache(
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# Create MLA KV cache: (num_blocks, block_size, head_size)
kv_cache = torch.empty(num_blocks,
block_size,
head_size,
dtype=dtype,
device=device)
kv_cache_flat = kv_cache.view(-1, head_size)
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
if use_fp8_ds_mla:
if not kv_c_contexts:
raise ValueError("kv_c_contexts cannot be empty when using"
" fp8_ds_mla cache dtype")
kv_lora_rank = kv_c_contexts[0].shape[-1]
rope_dim = k_pe_contexts[0].shape[-1]
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
kv_cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
device=device)
scale_tensor = (scale
if isinstance(scale, torch.Tensor) else torch.tensor(
scale, dtype=torch.float32, device=device))
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
else:
# Create MLA KV cache: (num_blocks, block_size, head_size)
kv_cache = torch.empty(num_blocks,
block_size,
head_size,
dtype=dtype,
device=device)
kv_cache_flat = kv_cache.view(-1, head_size)
# Populate the cache with the context tokens
# Start from block_id=1 since block_id=0 is considered the null block
start_block_idx = 1
for i in range(batch_size):
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
context_len = kv_c_context.shape[0]
if context_len == 0:
start_block_idx += cdiv(int(seq_lens[i]), block_size)
continue
start = start_block_idx * block_size
end = start + kv_context.shape[0]
kv_cache_flat[start:end, ...] = kv_context
if use_fp8_ds_mla:
slots = torch.arange(context_len, device=device,
dtype=torch.long) + start
ops.concat_and_cache_mla(
kv_c_context,
k_pe_context.squeeze(1),
kv_cache,
slots,
kv_cache_dtype="fp8_ds_mla",
scale=scale_tensor,
)
else:
kv_context = torch.cat(
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
end = start + kv_context.shape[0]
kv_cache_flat[start:end, ...] = kv_context
# Stay block aligned and allocate enough blocks for the new tokens
start_block_idx += cdiv(int(seq_lens[i]), block_size)
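A hedged sketch of the slot arithmetic used above, for reference: block 0 is reserved as the null block, so each request's context tokens land at consecutive slots starting at start_block_idx * block_size. The sizes below are illustrative toy values, not the test's configuration.

import torch
from vllm.utils import cdiv

block_size = 4
seq_lens = [6, 3]          # tokens per request (context + new)
context_lens = [5, 2]      # context tokens already in the cache
start_block_idx = 1        # block 0 is the null block
for context_len, seq_len in zip(context_lens, seq_lens):
    start = start_block_idx * block_size
    slots = torch.arange(context_len, dtype=torch.long) + start
    print(slots.tolist())  # request 0 -> [4..8], request 1 -> [12, 13]
    # stay block aligned: reserve enough blocks for the whole sequence
    start_block_idx += cdiv(seq_len, block_size)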

View File

@ -0,0 +1,426 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""
import math
from types import MethodType, SimpleNamespace
import numpy as np
import pytest
import torch
from tests.v1.attention.test_mla_backends import (
BATCH_SPECS, BatchSpec, MockAttentionLayer,
create_and_prepopulate_kv_cache)
from tests.v1.attention.utils import (create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl, FlashMLASparseMetadata)
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
for name in [
"mixed_small",
"mixed_medium",
"small_prefill",
"medium_prefill",
"single_prefill",
]
}
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
query_lens=[256] * 2)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens=[256] * 2, query_lens=[256] * 2)
def _dequantize_fp8_ds_mla_entry(
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
# The first kv_lora_rank bytes store FP8 latent values; one float32 scale per
# 128-element tile is written right after the latent payload.

scales = cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
latent = torch.empty(kv_lora_rank,
dtype=torch.float16,
device=cache_slice.device)
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = tile_start + 128
ops.convert_fp8(latent[tile_start:tile_end],
cache_slice[tile_start:tile_end],
float(scales[tile_idx].item()),
kv_dtype="fp8")
latent = latent.to(dtype)
rope_offset = kv_lora_rank // 2 + 8
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
return latent, rope_vals.clone()
def _quantize_dequantize_fp8_ds_mla(
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
if kv_c.numel() == 0:
return kv_c.clone(), k_pe.clone()
kv_lora_rank = kv_c.shape[-1]
rope_dim = k_pe.shape[-1]
num_tokens = kv_c.shape[0]
num_blocks = max(1, math.ceil(num_tokens / block_size))
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
tmp_cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
device=kv_c.device)
slot_mapping = torch.arange(num_tokens,
dtype=torch.long,
device=kv_c.device)
ops.concat_and_cache_mla(kv_c,
k_pe,
tmp_cache,
slot_mapping,
kv_cache_dtype="fp8_ds_mla",
scale=scale)
dequant_kv_c = torch.empty_like(kv_c)
dequant_k_pe = torch.empty_like(k_pe)
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
block_idx = slot // block_size
block_offset = slot % block_size
cache_slice = tmp_cache[block_idx, block_offset]
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
dequant_kv_c[token_idx] = latent
dequant_k_pe[token_idx] = rope_vals
return dequant_kv_c, dequant_k_pe
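For orientation, a hedged sketch of the fp8_ds_mla entry layout implied by the two helpers above; kv_lora_rank = 512 and rope_dim = 64 are assumptions that match the decode test below.

# Per-token byte layout of one fp8_ds_mla cache entry (assumed sizes):
kv_lora_rank, rope_dim = 512, 64
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim  # 512 fp8 bytes + 4 fp32 scales + 64 two-byte rope values = 656
latent_span = (0, kv_lora_rank)                   # fp8 latent values, one scale per 128-value tile
scale_span = (kv_lora_rank, kv_lora_rank + 16)    # 4 x float32 tile scales
rope_span = (kv_lora_rank + 16, entry_size)       # rope values stored in the model dtype (2 bytes each)
assert entry_size == 656
# matches the rope offset used by _dequantize_fp8_ds_mla_entry (in model-dtype units):
assert (kv_lora_rank // 2 + 8) * 2 == rope_span[0]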
def test_sparse_backend_metadata_registration():
backend = FlashMLASparseBackend
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
assert backend.get_metadata_cls() is FlashMLASparseMetadata
assert backend.get_impl_cls() is FlashMLASparseImpl
dtype_list = backend.get_supported_dtypes()
assert torch.bfloat16 in dtype_list
shape = backend.get_kv_cache_shape(num_blocks=2,
block_size=64,
num_kv_heads=1,
head_size=576)
assert shape == (2, 64, 576)
def test_sparse_decode_metadata_filters_prefill_indices():
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
metadata = FlashMLASparseDecodeAndContextMetadata(
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
num_splits=torch.tensor([1, 1], dtype=torch.int32),
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
prefill_context_lengths=prefill_context_lengths,
)
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
context_indices, new_token_indices = metadata.filter_prefill_indices(
indices)
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
dtype=torch.int32)
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
dtype=torch.int32)
assert torch.equal(context_indices, expected_context)
assert torch.equal(new_token_indices, expected_new_tokens)
def test_sparse_impl_zero_fills_when_metadata_missing():
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
dummy_layer = object()
q = torch.zeros((2, 1, 3))
k_c = torch.zeros((2, 3))
k_pe = torch.zeros((2, 1, 1))
kv_cache = torch.zeros((1, 1, 1))
output = torch.ones((2, 4))
result = FlashMLASparseImpl.forward(impl,
dummy_layer,
q,
k_c,
k_pe,
kv_cache,
attn_metadata=None,
output=output)
assert result is output
assert torch.all(result == 0)
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name,
kv_cache_dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
device = torch.device("cuda")
dtype = torch.bfloat16
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
# Model hyper-parameters (kept intentionally small for the unit test)
num_heads = 128
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
head_size = kv_lora_rank + qk_rope_head_dim
topk_tokens = 2048
max_seqlen = max(batch_spec.seq_lens)
total_cache_tokens = sum(batch_spec.seq_lens)
block_size = 64
vllm_config = create_vllm_config(
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
max_model_len=max_seqlen,
num_gpu_blocks=max(2048,
cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size)
model_config = vllm_config.model_config
model_config.hf_config = SimpleNamespace(
attn_module_list_cfg=[{
"topk_tokens": topk_tokens
}])
model_config.hf_text_config = SimpleNamespace(
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
model_type="deepseek_v2",
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: num_heads, model_config)
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
model_config)
model_config.get_head_size = MethodType(lambda self: head_size,
model_config)
model_config.get_sliding_window = MethodType(lambda self: None,
model_config)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
torch.manual_seed(0)
scale = 1.0 / math.sqrt(head_size)
# Shared MLA projection weights to keep reference and backend in sync
W_UK = torch.randn(kv_lora_rank,
num_heads,
qk_nope_head_dim,
dtype=dtype,
device=device)
W_UV = torch.randn(kv_lora_rank,
num_heads,
v_head_dim,
dtype=dtype,
device=device)
# Build synthetic decode-only workload
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
kv_c_contexts, k_pe_contexts = [], []
reference_outputs = []
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
for i in range(batch_spec.batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
ctx_len = s_len - q_len
q_c = torch.rand(q_len,
num_heads,
qk_nope_head_dim + qk_rope_head_dim,
dtype=dtype,
device=device)
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
k_pe_full = torch.rand(s_len,
1,
qk_rope_head_dim,
dtype=dtype,
device=device)
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
k_pe_full.squeeze(1),
block_size=vllm_config.cache_config.block_size,
scale=kv_cache_scale,
)
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, ctx_len:] = causal_mask
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
all_q_vllm.append(q_c)
all_kv_c_vllm.append(kv_c_full[ctx_len:])
all_k_pe_vllm.append(k_pe_full[ctx_len:])
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_reference = torch.cat(reference_outputs, dim=0)
vllm_config.cache_config.cache_dtype = kv_cache_dtype
common_attn_metadata = create_common_attn_metadata(
batch_spec,
vllm_config.cache_config.block_size,
device,
arange_block_indices=True)
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=vllm_config.cache_config.block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=False,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
scale=kv_cache_scale,
)
builder_cls = FlashMLASparseBackend.get_builder_cls()
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
metadata = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
seg_lengths = np.diff(starts)
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths)
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths)
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
topk = metadata.topk_tokens
debug_indices = torch.arange(topk, device=device,
dtype=torch.int32).unsqueeze(0)
token_positions = pos_gpu.unsqueeze(1)
causal_mask = (debug_indices <= token_positions)
debug_indices = torch.where(causal_mask, debug_indices,
torch.full_like(debug_indices, -1))
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
# buffer, so emulate that contract with a simple namespace mock.
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
-1).clone()
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
ok, reason = flashmla.is_flashmla_supported()
if not ok:
pytest.skip(reason)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
output_size=num_heads *
(qk_nope_head_dim + v_head_dim),
bias=False).to(device=device,
dtype=dtype)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer)
impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device)
out_buffer = torch.empty(metadata.num_actual_tokens,
num_heads * v_head_dim,
dtype=dtype,
device=device)
backend_output = impl.forward(layer,
query_vllm,
kv_c_vllm,
k_pe_vllm,
kv_cache,
metadata,
output=out_buffer)
assert backend_output.shape == sdpa_reference.shape
assert backend_output.dtype == sdpa_reference.dtype
assert torch.isfinite(backend_output).all()
torch.testing.assert_close(backend_output,
sdpa_reference,
rtol=0.5,
atol=0.5)
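The reference computation above relies on the weight-absorbed MLA formulation: queries are projected into the latent space with W_UK, attention runs as MQA over [latent | rope] keys with the latent vectors reused as values, and W_UV maps the result back to per-head value dimensions. A self-contained sketch with toy sizes (assumptions, not the test's configuration):

import torch
import torch.nn.functional as F

q_len, H, d_nope, d_rope, d_v, r = 4, 2, 8, 4, 8, 16   # toy sizes
q_nope = torch.randn(q_len, H, d_nope)
q_pe = torch.randn(q_len, H, d_rope)
kv_c = torch.randn(q_len, r)                    # latent KV entries
k_pe = torch.randn(q_len, d_rope)
W_UK = torch.randn(r, H, d_nope)
W_UV = torch.randn(r, H, d_v)

ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)     # absorb W_UK into the query
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)               # [q_len, H, r + d_rope]
k_mqa = torch.cat([kv_c, k_pe], dim=-1).unsqueeze(1).expand(-1, H, -1)
v_mqa = kv_c.unsqueeze(1).expand(-1, H, -1)              # latents reused as values
attn = F.scaled_dot_product_attention(q_mqa.transpose(0, 1),
                                       k_mqa.transpose(0, 1),
                                       v_mqa.transpose(0, 1),
                                       is_causal=True)
out = torch.einsum("qnl,lnv->qnv", attn.transpose(0, 1), W_UV)  # project back with W_UV
assert out.shape == (q_len, H, d_v)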

View File

@ -168,7 +168,6 @@ def create_standard_kv_cache_spec(
vllm_config.parallel_config),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
use_mla=vllm_config.model_config.use_mla,
sliding_window=vllm_config.model_config.get_sliding_window(),
)

View File

@ -24,7 +24,8 @@ from vllm.v1.core.kv_cache_utils import (
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec,
KVCacheTensor, MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla,
sliding_window=sliding_window)
@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False,
sliding_window=1):
return SlidingWindowSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla,
sliding_window=sliding_window)
@ -894,7 +891,6 @@ def test_merge_kv_cache_spec():
num_kv_heads=full_spec.num_kv_heads,
head_size=full_spec.head_size,
dtype=full_spec.dtype,
use_mla=full_spec.use_mla,
sliding_window=1,
),
]
@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len,
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
)
# Estimate the maximum model length, 16384 model_len need 8GB
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config():
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
)
sliding_window_spec = SlidingWindowSpec(
@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config():
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
use_mla=False,
sliding_window=1024,
)
@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config():
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
],
)
def new_mla_spec(cache_dtype_str=None):
return MLAAttentionSpec(block_size=16,
num_kv_heads=16,
head_size=64,
dtype=torch.float32,
cache_dtype_str=cache_dtype_str)
def test_merge_mla_spec():
kv_cache_specs = [
new_mla_spec(),
new_mla_spec(),
]
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
assert mla_spec == new_mla_spec()
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
]
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_mla_spec(cache_dtype_str=None),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
kv_cache_specs = [
new_kv_cache_spec(),
new_mla_spec(),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_kv_cache_spec(),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)

View File

@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
FullAttentionSpec(block_size, 1, 1, torch.float32),
)
],
)
@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
FullAttentionSpec(block_size, 1, 1, torch.float32),
),
KVCacheGroupSpec(
["layer2"],
@ -98,7 +98,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
1,
1,
torch.float32,
False,
sliding_window=2 * block_size),
),
KVCacheGroupSpec(
@ -107,7 +106,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
1,
1,
torch.float32,
False,
sliding_window=2 * block_size),
),
],
@ -1338,7 +1336,6 @@ def test_eagle_with_sliding_window():
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
use_mla=False,
)
manager = KVCacheManager(
KVCacheConfig(

View File

@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@ -100,7 +99,6 @@ def test_sliding_window_possible_cached_prefix():
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@ -165,7 +163,6 @@ def test_chunked_local_attention_remove_skipped_blocks():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
@ -217,7 +214,6 @@ def test_sliding_window_remove_skipped_blocks():
head_size=1,
dtype=torch.float32,
sliding_window=4,
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
@ -285,7 +281,6 @@ def test_get_num_blocks_to_allocate():
head_size=1,
dtype=torch.float32,
sliding_window=4, # Placeholder value, not related to test result
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@ -308,7 +303,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4, # Placeholder value, not related to test result
use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)

View File

@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
MTP_SIMILARITY_RATE = 0.8
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
@ -222,3 +224,66 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
ids=["mimo", "deepseek"])
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using a fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of an original LLM and a speculative LLM;
they should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
method, model_name, tp_size = model_setup
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"num_speculative_tokens": 1,
"max_model_len": 2048,
},
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 80% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

View File

@ -836,8 +836,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
mock_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
dtype=torch.float16)
mock_executor.get_kv_cache_specs.return_value = [{
"default": mock_spec

View File

@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values.
self.slot_size_bytes = 4096
self.block_len = self.slot_size_bytes * self.block_size
slot_size_bytes = 4096
self.slot_size_per_layer = [slot_size_bytes]
self.block_len_per_layer = [slot_size_bytes * self.block_size]
self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks
@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=self.block_len,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here.
@ -485,8 +486,8 @@ class TestNixlHandshake:
worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent
worker.slot_size_bytes = 4096
worker.block_len = worker.slot_size_bytes * worker.block_size
worker.slot_size_per_layer = [4096]
worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
@ -498,7 +499,7 @@ class TestNixlHandshake:
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
block_len=worker.block_len,
block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout,
)

View File

@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
"target_attn_1": mock.MagicMock(),
"target_attn_2": mock.MagicMock()
}
target_indx_layers: dict[str, mock.MagicMock] = {}
# Draft model has one extra attention layer compared to target model
all_attn_layers = {
**target_attn_layers, "draft_extra_attn": mock.MagicMock()
}
all_indx_layers: dict[str, mock.MagicMock] = {}
# Make mock_get_layers return different values for each call
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
mock_get_layers.side_effect = [
target_attn_layers, target_indx_layers, all_attn_layers,
all_indx_layers
]
# Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock()
@ -658,6 +664,9 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building.
proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock(

View File

@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
"""Create an MTP proposer with unified model configuration."""
model_config = ModelConfig(model=mimo_7b_dir,
runner="generate",
max_model_len=100,
trust_remote_code=True)
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=mimo_7b_dir,
method="mtp",
num_speculative_tokens=num_speculative_tokens,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())
return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach."""
# Setup mocks
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model
target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
target_indexer_layers: dict = {}
all_indexer_layers: dict = {}
mock_get_layers.side_effect = [
target_attn_layers, target_indexer_layers, all_attn_layers,
all_indexer_layers
]
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
mock_get_pp_group.return_value = mock_pp_group
# Create target model
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
target_model.lm_head = mock.MagicMock()
# Create MTP proposer
proposer = _create_mtp_proposer(num_speculative_tokens=4)
proposer.load_model(target_model)
# Verify MTP-specific behavior:
# Model is loaded
mock_get_model.assert_called_once()
# MTP shares lm_head with target model
assert proposer.model.lm_head == target_model.lm_head
# MTP shares embed_tokens with target model
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
"""Test that MTP's forward method returns hidden states directly"""
device = torch.device(current_platform.device_type)
batch_size = 2
seq_lens = [5, 3]
total_tokens = sum(seq_lens)
vocab_size = 100
proposer = _create_mtp_proposer(num_speculative_tokens)
hidden_size = proposer.hidden_size
# Mock the MTP model to verify it returns hidden states directly
model_mock = mock.MagicMock()
# MTP returns hidden states directly
if num_speculative_tokens == 1:
model_mock.return_value = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
# Multiple forward passes for multi-token speculation
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append(h_states)
model_mock.side_effect = forward_returns
# Mock compute_logits
def create_deterministic_logits(batch_size, vocab_size, token_offset):
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
logits[:, token_offset] = 100.0
return logits
if num_speculative_tokens == 1:
model_mock.compute_logits.return_value = create_deterministic_logits(
batch_size, vocab_size, 42)
else:
logits_returns = [
create_deterministic_logits(batch_size, vocab_size, 42 + i)
for i in range(num_speculative_tokens)
]
model_mock.compute_logits.side_effect = logits_returns
proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"]
# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()
# Setup attention metadata
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
proposer.runner = mock.MagicMock()
proposer.attn_metadata_builder = attn_metadata_builder
# Run propose
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
# Verify the model was called correctly
assert model_mock.called
# Verify output shape
assert result.shape == (batch_size, num_speculative_tokens)

View File

@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner):
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(

View File

@ -1678,6 +1678,15 @@ def cp_gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts)
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
quant_block_size: int,
kv_cache_dtype: str) -> None:
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
quant_block_size,
kv_cache_dtype)
def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)

View File

@ -70,6 +70,7 @@ class AttentionBackend(ABC):
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
raise NotImplementedError

View File

@ -95,6 +95,7 @@ class Attention(nn.Module, AttentionLayerBase):
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
use_sparse: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
@ -155,6 +156,7 @@ class Attention(nn.Module, AttentionLayerBase):
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.use_sparse = use_sparse
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
@ -187,7 +189,8 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype,
block_size,
use_mla=use_mla,
has_sink=self.has_sink)
has_sink=self.has_sink,
use_sparse=use_sparse)
else:
self.attn_backend = attn_backend

View File

@ -138,3 +138,208 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out = cp_group.reduce_scatter(out, dim=1)
return out
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
out_ptr, # [B, Lmax, D]
lengths_ptr, # *i32, [B]
N: tl.constexpr,
D: tl.constexpr,
Lmax: tl.constexpr,
PAD_VALUE: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# Compute start index and sequence length from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
# compute input row indices for valid (b, t)
in_row = in_start + off_t
valid_row = (off_t < seq_len) & t_mask
# Pointers
# x_ptr: row-major [N, D]
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
# out_ptr: row-major [B, Lmax, D]
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
None] * D + off_d[None, :]
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
d_mask = off_d[None, :] < D
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
# Load & write only where within seq_len
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
def pack_seq_triton(x: torch.Tensor,
lengths: torch.Tensor,
pad_value: float = -float('inf'),
block_t: int = 64,
block_d: int = 64) -> torch.Tensor:
"""
Pack sequences of different lengths into a batched tensor.
Args:
x: [N, ...] - input tensor where N is total number of tokens
lengths: [B] - sequence lengths for each batch
pad_value: value to use for padding
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
packed: [B, Lmax, ...] - packed tensor
"""
# Handle multi-dimensional input by reshaping to (N, -1)
original_shape = x.shape
if len(original_shape) > 2:
N = original_shape[0]
x_reshaped = x.reshape(N, -1)
D = x_reshaped.shape[1]
else:
N, D = x.shape
x_reshaped = x
B = lengths.numel()
Lmax = int(lengths.max().item())
# Starts are computed inside the kernel from lengths
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_pack_seq_kernel[grid](x_reshaped,
out,
lengths.int(),
N,
D,
Lmax,
PAD_VALUE=float(pad_value),
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 2:
output_shape = (B, Lmax) + original_shape[1:]
out = out.reshape(output_shape)
return out
@triton.jit
def _unpack_seq_triton_kernel(
packed_ptr, # [B, Lmax, D]
out_ptr, # [N, D]
lengths_ptr, # *i32, [B]
B: tl.constexpr,
Lmax: tl.constexpr,
D: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# bounds: compute start from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
valid_row = (off_t < seq_len) & t_mask
# compute output row indices for valid (b, t)
out_row = in_start + off_t
# Pointers
# packed_ptr: row-major [B, Lmax, D]
packed_row_ptr = packed_ptr + (pid_b * Lmax +
off_t)[:, None] * D + off_d[None, :]
# out_ptr: row-major [N, D]
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
# Load from packed tensor and store to output
d_mask = off_d[None, :] < D
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
def unpack_seq_triton(packed_tensor: torch.Tensor,
lengths: torch.Tensor,
block_t: int = 64,
block_d: int = 64) -> torch.Tensor:
"""
Unpack a packed decode query tensor back to the original format.
Efficient Triton implementation.
Args:
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
lengths: [B] - sequence lengths for each batch
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
unpacked_tensor: [N, ...] where N = sum(lengths)
"""
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
original_shape = packed_tensor.shape
if len(original_shape) > 3:
B, Lmax = original_shape[:2]
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
D = packed_reshaped.shape[2]
else:
B, Lmax, D = packed_tensor.shape
packed_reshaped = packed_tensor
# Calculate total number of elements
N = int(lengths.sum().item())
out = torch.empty((N, D),
device=packed_tensor.device,
dtype=packed_tensor.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_unpack_seq_triton_kernel[grid](packed_reshaped,
out,
lengths.int(),
B,
Lmax,
D,
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 3:
output_shape = (N, ) + original_shape[2:]
out = out.reshape(output_shape)
return out
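A minimal round-trip sketch of the two helpers above. The import path is not shown in this diff, so the functions are assumed to be in scope, and a CUDA device is assumed since the kernels are Triton.

import torch

lengths = torch.tensor([3, 5, 2], device="cuda")            # B = 3 sequences
x = torch.randn(int(lengths.sum()), 8, 16, device="cuda")   # [N, H, D] with N = 10

packed = pack_seq_triton(x, lengths)          # [B, Lmax, H, D] = [3, 5, 8, 16], padded with -inf
restored = unpack_seq_triton(packed, lengths)  # back to [N, H, D]
assert restored.shape == x.shape
assert torch.allclose(restored.float(), x.float())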

View File

@ -19,6 +19,15 @@ if current_platform.is_cuda():
else:
_flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
"""
@ -37,24 +46,34 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k:
Equal to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q:
The number of q heads.
This argument is optional when sparse attention is not enabled.
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled,
and only tokens in the `indices` array
passed to `flash_mla_with_kvcache_sm90` will be attended to.
Return:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
Returns:
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
num_heads_per_head_k,
num_heads_k)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
is_fp8_kvcache, topk)
def flash_mla_with_kvcache(
@ -69,45 +88,95 @@ def flash_mla_with_kvcache(
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
torch.int32, return by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q.
descale_k: (batch_size), torch.float32. Descaling factors for K.
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
returned by get_mla_metadata.
- num_splits:
(batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float.
The scale of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k,
)
if indices is not None:
# NOTE (zyongye): sparse attention is also causal
# since it only attends to the tokens before,
# but `causal` should not be specified here.
assert not causal, \
"causal must be `false` if sparse attention is enabled."
assert (descale_q is None) == (
descale_k is None
), "descale_q and descale_k should be both None or both not None"
# Note(hc): needs revisiting when we support DCP with decode query_len > 1.
return out.squeeze(1), softmax_lse.squeeze(-1)
if indices is None and q.element_size() == 1:
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
indices)
return out, softmax_lse
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
For the definitions of output, max_logits and lse,
please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
return results
#
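A hedged decode-path sketch tying the two entry points together; argument order follows the docstrings above, the tensor contents are placeholders, and a GPU supported by FlashMLA is assumed.

import torch
from vllm.attention.ops import flashmla

ok, reason = flashmla.is_flashmla_supported()
assert ok, reason

batch, s_q, h_q, h_kv, d_qk, d_v = 4, 1, 128, 1, 576, 512
block_size, num_blocks = 64, 64
cache_seqlens = torch.full((batch, ), 1024, dtype=torch.int32, device="cuda")

tile_md, num_splits = flashmla.get_mla_metadata(
    cache_seqlens,
    s_q * h_q // h_kv,   # num_q_tokens_per_head_k
    h_kv,
)

q = torch.randn(batch, s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(num_blocks, block_size, h_kv, d_qk,
                      dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(batch * 16, dtype=torch.int32,
                           device="cuda").view(batch, 16)
out, lse = flashmla.flash_mla_with_kvcache(q, k_cache, block_table,
                                           cache_seqlens, d_v, tile_md,
                                           num_splits, causal=True)
# out: [batch, s_q, h_q, d_v], lse: [batch, h_q, s_q]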

View File

@ -50,6 +50,7 @@ class PagedAttention:
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)

View File

@ -144,6 +144,7 @@ def get_attn_backend(
block_size: int,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
@ -158,6 +159,7 @@ def get_attn_backend(
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
has_sink=has_sink,
use_sparse=use_sparse,
)
@ -170,6 +172,7 @@ def _cached_get_attn_backend(
use_v1: bool = False,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
@ -203,7 +206,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla, has_sink)
use_mla, has_sink, use_sparse)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")

View File

@ -22,7 +22,8 @@ else:
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2",
"fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
@ -52,7 +53,11 @@ class CacheConfig:
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to
use bfloat16 instead. bfloat16 is an invalid option for models that do not
default to fp8.
"""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
@ -171,11 +176,12 @@ class CacheConfig:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in get_args(CacheDType):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor.")
if self.cache_dtype.startswith("fp8"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

View File

@ -360,6 +360,7 @@ class CompilationConfig:
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
"vllm.sparse_attn_indexer",
]
def compute_hash(self) -> str:

View File

@ -137,6 +137,9 @@ class ModelConfig:
"""Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only
be enabled in trusted environments."""
allowed_media_domains: Optional[list[str]] = None
"""If set, only media URLs that belong to this domain can be used for
multi-modal inputs. """
revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version."""
@ -1074,14 +1077,14 @@ class ModelConfig:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp',
'kimi_k2', 'longcat_flash'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
# underlying architecture
return self.hf_text_config.model.model_type in \
('deepseek_v2', 'deepseek_v3') \
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
and self.hf_text_config.kv_lora_rank is not None
return False

View File

@ -279,6 +279,24 @@ class ParallelConfig:
assert last_exc is not None
raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
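A small standalone restatement of when the new property is true, matching the conditions listed in the comment above (pplx is excluded because it tolerates duplicate input tokens):

def use_sequence_parallel_moe(all2all_backend: str, enable_expert_parallel: bool,
                              tensor_parallel_size: int,
                              data_parallel_size: int) -> bool:
    return (all2all_backend in ("allgather_reducescatter", "naive",
                                "deepep_high_throughput", "deepep_low_latency")
            and enable_expert_parallel
            and tensor_parallel_size > 1
            and data_parallel_size > 1)

assert use_sequence_parallel_moe("deepep_high_throughput", True, 2, 2)
assert not use_sequence_parallel_moe("pplx", True, 2, 2)    # pplx dedups on its own
assert not use_sequence_parallel_moe("naive", True, 2, 1)   # no DP, nothing is replicated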
@staticmethod
def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool:

View File

@ -32,14 +32,17 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]
"longcat_flash_mtp", "mtp"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
@config
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
enforce_eager: Optional[bool] = None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control
num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the
@ -143,7 +146,7 @@ class SpeculativeConfig:
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
@ -207,11 +210,21 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.method in MTP_MODEL_TYPES:
logger.warning("method `%s` is deprecated and replaced with mtp.",
self.method)
self.method = "mtp"
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
if (self.target_model_config
and self.target_model_config.hf_text_config.model_type
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
if self.method == "mtp":
assert (
self.target_model_config
is not None), "target_model_config must be present for mtp"
if self.target_model_config.hf_text_config.model_type \
== "deepseek_v32":
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
# remove this when the issue is fixed.
self.enforce_eager = True
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
@ -281,6 +294,8 @@ class SpeculativeConfig:
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
@ -312,31 +327,13 @@ class SpeculativeConfig:
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
in MTP_MODEL_TYPES):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
"Enabling num_speculative_tokens > 1 will run" \
"multiple times of forward on same MTP layer" \
",which may result in lower acceptance rate" \
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
@ -353,7 +350,7 @@ class SpeculativeConfig:
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
"eagle, or mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@ -562,8 +559,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens
def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
return self.method in ("eagle", "eagle3", "mtp")
def __repr__(self) -> str:
method = self.method

View File

@ -6,7 +6,7 @@ import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device,
dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
rank = self.rank if is_sequence_parallel else self.dp_rank
world_size = (self.world_size
if is_sequence_parallel else self.dp_world_size)
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x)
for idx in range(self.dp_world_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx]
self.dp_group.broadcast(buffer[start:end, :], idx)
for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_sp_cpu[idx]
get_ep_group().broadcast(buffer[start:end, :], idx)
return buffer
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
dp_metadata = get_forward_context().dp_metadata
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]
all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states
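A small self-contained check of the offset arithmetic used above, with toy cumulative token counts and no distributed setup:

import torch

cu_tokens_across_sp_cpu = torch.tensor([3, 7, 12])  # 3 ranks: 3, 4, 5 tokens
rank = 1
start = 0 if rank == 0 else int(cu_tokens_across_sp_cpu[rank - 1])
end = int(cu_tokens_across_sp_cpu[rank])
# Rank 1's shard occupies rows [3, 7) of the gathered buffer.
assert (start, end) == (3, 7)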
def destroy(self):
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group):
super().__init__(cpu_group)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states
def destroy(self):
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs):
raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def dispatch(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError
def destroy(self):
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.workspace_tensor = None
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False
self.initialized = False

View File

@ -28,6 +28,8 @@ class Cache:
class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group):
self.cpu_group = cpu_group
@ -40,6 +42,7 @@ class All2AllManagerBase:
# all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group()
self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
@ -60,17 +63,21 @@ class All2AllManagerBase:
# and reuse it for the same config.
raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def set_num_sms(self, num_sms: int):
pass
def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def destroy(self):
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
module.quant_method.init_prepare_finalize(module)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.

View File

@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
# ep does not use pynccl
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states

View File

@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
dist.broadcast(input_, src=src, group=self.device_group)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits)
hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states)
hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states

View File

@ -84,7 +84,7 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
block_len: int
block_lens: list[int]
attn_backend_name: str
kv_cache_layout: str
@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_in_batch: set[ReqId] = set()
def add_new_req(
self,
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
def get_num_new_matched_tokens(
self, request: "Request",
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
if not params:
return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
load_remote_cache=True,
save_to_host=False,
)
for req_id, (req, block_ids) in self._reqs_need_save.items():
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
)
meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_need_send = {}
return meta
@ -465,8 +474,11 @@ class NixlConnectorWorker:
"backends", ["UCX"])
# Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 and nixl_agent_config is not None else None
if nixl_agent_config is None:
config = None
else:
config = nixl_agent_config(backends=self.nixl_backends) if len(
non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@ -546,6 +558,8 @@ class NixlConnectorWorker:
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
# Set of requests that have been part of a batch, regardless of status.
self._reqs_to_process: set[ReqId] = set()
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@ -752,6 +766,9 @@ class NixlConnectorWorker:
split_k_and_v = not (self.use_mla or self._use_pallas
or self._use_flashinfer)
tensor_size_bytes = None
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
@ -769,10 +786,25 @@ class NixlConnectorWorker:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
assert cache.shape[0] == self.num_blocks, \
"All kv cache tensors must have the same number of blocks"
self.block_len_per_layer.append(curr_tensor_size_bytes //
self.num_blocks)
self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
self.block_size)
if not self.use_mla:
# Different kv cache shape is not supported by HeteroTP
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
(base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
logger.debug("Different block lengths collected: %s",
set(self.block_len_per_layer))
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
@ -785,16 +817,12 @@ class NixlConnectorWorker:
logger.debug("Done registering descs")
self._registered_descs.append(descs)
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
self.slot_size_per_layer[i] //= 2
# NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in
@ -804,17 +832,17 @@ class NixlConnectorWorker:
# of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2
kv_block_len = self.get_backend_aware_kv_block_len()
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in seen_base_addresses:
for i, base_addr in enumerate(seen_base_addresses):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
# select agent_meta.num_blocks instead of self.num_blocks for
# local descr, and that makes handling regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# (addr, len, device id)
blocks_data.append((addr, kv_block_len, self.tp_rank))
@ -824,7 +852,7 @@ class NixlConnectorWorker:
# descs ordering. This is needed for selecting contiguous heads
# when split across TP ranks.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
block_offset = block_id * self.block_len_per_layer[i]
addr = base_addr + block_offset
# Register addresses for V cache (K registered first).
v_addr = addr + kv_block_len
@ -864,7 +892,7 @@ class NixlConnectorWorker:
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
block_len=self.block_len,
block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout)
ready_event = threading.Event()
@ -889,7 +917,7 @@ class NixlConnectorWorker:
The latter, assuming D.world_size > P.world_size, requires that two or
more local TP workers share the xfer from a single TP worker.
Here's an example:
Here's an example (non-MLA case):
rank_offset p_remote_tp_rank
(kv split no)
@ -945,14 +973,20 @@ class NixlConnectorWorker:
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or is_kv_replicated:
# With MLA the only difference is in the number of blocks.
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes)
assert self.block_len == nixl_agent_meta.block_len
# With replicated KV cache, only the number of blocks can differ.
assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
"KV cache sizes must match between P and D when replicated"
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0])
else:
remote_block_size = nixl_agent_meta.block_len // (
self.slot_size_bytes * tp_ratio)
# When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, \
"All remote layers must have the same block size"
remote_block_size = remote_block_len // (
self.slot_size_per_layer[0] * tp_ratio)
if self._use_flashinfer:
# With flashinfer, KV are sent in the same message.
remote_block_size //= 2
@ -963,14 +997,14 @@ class NixlConnectorWorker:
raise ValueError(
"Heterogeneous TP is not supported on XPU")
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
assert self.block_size == remote_block_size, (
"Remote P worker with different block size is not supported "
f"{self.block_size=} {remote_block_size=}")
"Remote P worker with different page/block size is not supported "
f"{self.block_size=}, {remote_block_size=}")
# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
@ -985,13 +1019,16 @@ class NixlConnectorWorker:
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
kv_block_len = self.get_backend_aware_kv_block_len()
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0
assert len(nixl_agent_meta.kv_caches_base_addr) == len(
self.block_len_per_layer)
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
if not (self.use_mla or is_kv_replicated) else 0
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
block_offset = block_id * nixl_agent_meta.block_lens[i]
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
@ -1002,9 +1039,9 @@ class NixlConnectorWorker:
if self._use_flashinfer:
# With FlashInfer index V separately to allow head splitting.
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
block_offset = block_id * nixl_agent_meta.block_lens[i]
addr = base_addr + block_offset + rank_offset
v_addr = addr + nixl_agent_meta.block_len // 2
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
logger.debug(
@ -1082,6 +1119,7 @@ class NixlConnectorWorker:
"Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.", req_id,
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
self._reqs_to_process.remove(req_id)
del self._reqs_to_send[req_id]
done_sending.add(req_id)
@ -1097,7 +1135,8 @@ class NixlConnectorWorker:
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
if req_id not in self._reqs_to_send:
if (req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process):
logger.error(
"Potentially invalid KV blocks for "
"unrecognized request %s were retrieved by "
@ -1110,7 +1149,8 @@ class NixlConnectorWorker:
tp_ratio):
notified_req_ids.add(req_id)
del self.consumer_notification_counts_by_req[req_id]
del self._reqs_to_send[req_id]
self._reqs_to_process.remove(req_id)
self._reqs_to_send.pop(req_id, None)
return notified_req_ids
def _pop_done_transfers(
@ -1171,8 +1211,19 @@ class NixlConnectorWorker:
while not self._ready_requests.empty():
self._read_blocks_for_req(*self._ready_requests.get_nowait())
# Keep around the requests that have been part of a batch. This is
# needed because async scheduling widens the misalignment between the
# moment in which request expiration is set (P side) and the moment in
# which blocks are read from D. As P can now more easily lag behind D
# while processing the next batch, we make sure to only set an
# expiration for requests that have not been read from D yet.
for req_id in metadata.reqs_in_batch:
self._reqs_to_process.add(req_id)
# Add to requests that are waiting to be read and track expiration.
self._reqs_to_send.update(metadata.reqs_to_send)
for req_id, expiration_time in metadata.reqs_to_send.items():
if req_id in self._reqs_to_process:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug(
@ -1317,7 +1368,7 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
def get_backend_aware_kv_block_len(self):
def get_backend_aware_kv_block_len(self, layer_idx: int):
"""
Get the block length for one K/V element (K and V have the same size).
@ -1328,9 +1379,9 @@ class NixlConnectorWorker:
"""
if self._use_flashinfer:
# For indexing only half (either just the K or V part).
block_len = self.block_len // 2
block_len = self.block_len_per_layer[layer_idx] // 2
else:
block_len = self.block_len
block_len = self.block_len_per_layer[layer_idx]
return block_len
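To make the per-layer bookkeeping concrete, a toy calculation with made-up sizes (the real values come from the registered KV cache tensors):

# One layer's cache tensor: 1024 blocks, 64 MiB total, block_size of 16 slots.
num_blocks = 1024
tensor_size_bytes = 64 * 1024 * 1024
block_size = 16

block_len = tensor_size_bytes // num_blocks   # 65536 bytes per block
slot_size = block_len // block_size           # 4096 bytes per slot
flashinfer_kv_block_len = block_len // 2      # K or V half when FlashInfer is used

assert (block_len, slot_size, flashinfer_kv_block_len) == (65536, 4096, 32768)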
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:

View File

@ -871,17 +871,24 @@ class GroupCoordinator:
model)
def dispatch(
self, hidden_states: torch.Tensor,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
if self.device_communicator is not None:
return self.device_communicator.dispatch(hidden_states,
router_logits)
router_logits,
is_sequence_parallel)
else:
return hidden_states, router_logits
def combine(self, hidden_states) -> torch.Tensor:
def combine(self,
hidden_states,
is_sequence_parallel: bool = False) -> torch.Tensor:
if self.device_communicator is not None:
return self.device_communicator.combine(hidden_states)
return self.device_communicator.combine(hidden_states,
is_sequence_parallel)
else:
return hidden_states

View File

@ -297,6 +297,8 @@ class EngineArgs:
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: Optional[
list[str]] = ModelConfig.allowed_media_domains
download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format
@ -531,6 +533,8 @@ class EngineArgs:
**model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"])
model_group.add_argument("--allowed-media-domains",
**model_kwargs["allowed_media_domains"])
model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision",
**model_kwargs["code_revision"])
@ -997,6 +1001,7 @@ class EngineArgs:
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
dtype=self.dtype,
seed=self.seed,
revision=self.revision,
@ -1481,7 +1486,7 @@ class EngineArgs:
raise NotImplementedError(
"Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")
"such as ngram, medusa, eagle, or mtp.")
V1_BACKENDS = [
"FLASH_ATTN",

View File

@ -11,7 +11,12 @@ from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast)
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils import random_uuid, supports_kw
logger = init_logger(__name__)
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@property
def allowed_media_domains(self):
return self._model_config.allowed_media_domains
@property
def mm_registry(self):
return MULTIMODAL_REGISTRY
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
)
def parse_image(
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
)
def parse_image(
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
# We exclude chat_template from kwargs here, because
# the chat template has already been resolved at this stage
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
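A minimal, self-contained illustration of the filtering idea (toy template and kwargs; the real code also keeps kwargs accepted by tokenizer.apply_chat_template):

import jinja2.meta
import jinja2.sandbox

template = "{{ bos_token }}{% if enable_thinking %}<think>{% endif %}"
env = jinja2.sandbox.ImmutableSandboxedEnvironment()
template_vars = jinja2.meta.find_undeclared_variables(env.parse(template))
kwargs = {"enable_thinking": True, "unknown_flag": 1, "chat_template": "..."}
resolved = {k: v
            for k, v in kwargs.items()
            if k in template_vars and k != "chat_template"}
assert resolved == {"enable_thinking": True}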
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
)
try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
**resolved_kwargs,
)
# External library exceptions can sometimes occur despite the framework's

View File

@ -86,6 +86,8 @@ class LLM:
or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
allowed_media_domains: If set, only media URLs that belong to one of
these domains can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
@ -169,6 +171,7 @@ class LLM:
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
tensor_parallel_size: int = 1,
dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None,
@ -264,6 +267,7 @@ class LLM:
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path,
allowed_media_domains=allowed_media_domains,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
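A hedged usage sketch for the new argument (the model id and domain are placeholders):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",             # placeholder multimodal model
    allowed_media_domains=["upload.wikimedia.org"],  # only this host is allowed
)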

View File

@ -3,12 +3,14 @@
import asyncio
import gc
import hashlib
import importlib
import inspect
import json
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import secrets
import signal
import socket
import tempfile
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
class AuthenticationMiddleware:
"""
Pure ASGI middleware that authenticates each request by checking
if the Authorization header exists and equals "Bearer {api_key}".
if the Authorization Bearer token exists and equals any of "{api_key}".
Notes
-----
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app
self.api_tokens = {f"Bearer {token}" for token in tokens}
self.api_tokens = [
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
]
def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False
scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)
return token_match
def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]:
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope)
# Type narrow to satisfy mypy.
if url_path.startswith("/v1") and headers.get(
"Authorization") not in self.api_tokens:
if url_path.startswith("/v1") and not self.verify_token(headers):
response = JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return response(scope, receive, send)
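A compact standalone sketch of the comparison strategy above: hash the configured keys once, then compare digests in constant time on every request (keys below are made up).

import hashlib
import secrets

api_tokens = [hashlib.sha256(t.encode("utf-8")).digest()
              for t in ("key-one", "key-two")]

def verify(presented: str) -> bool:
    digest = hashlib.sha256(presented.encode("utf-8")).digest()
    match = False
    for token_hash in api_tokens:
        # OR-accumulate so every configured token is compared each time.
        match |= secrets.compare_digest(digest, token_hash)
    return match

assert verify("key-two") and not verify("wrong-key")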
@ -1696,6 +1716,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.

View File

@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the one from the tokenizer."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None

View File

@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "",
enable_auto_tools: bool = False,
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs
# set up tool use
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
if not self.use_harmony:
# Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
(
conversation,
request_prompts,
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
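A hedged client-side sketch of the new check (assuming the FrontendArgs field surfaces as --trust-request-chat-template via the usual dash-case mapping, and with a placeholder model name): with the default of False, a request that ships its own chat template gets an error response instead of being rendered.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
# Expected to be refused unless the server was launched with
# --trust-request-chat-template.
client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
    extra_body={"chat_template": "{{ messages }}"},
)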

View File

@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
return BatchDescriptor(self.num_tokens, uniform_decode=False)
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int) -> list[int]:
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
sequence_parallel_size)
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
return sp_tokens.tolist()
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
sequence_parallel_size: int,
max_num_tokens: int,
chunk_idx: int) -> list[int]:
dp_size = len(num_tokens_across_dp_cpu)
local_size = [-1] * dp_size
for i in range(dp_size):
dp_tokens = num_tokens_across_dp_cpu[i]
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
sequence_parallel_size)
sp_size = len(sp_tokens)
local_size = [-1] * sp_size
for i in range(sp_size):
# Take into account sharding if MoE activation is sequence parallel.
local_size[i] = min(max_num_tokens,
dp_tokens - (max_num_tokens * chunk_idx))
sp_tokens[i] - (max_num_tokens * chunk_idx))
if local_size[i] <= 0:
local_size[i] = 1 # ensure lockstep even if done
return local_size
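To see the split concretely, a toy evaluation of the same arithmetic (made-up counts: two DP ranks with 10 and 6 tokens, sp_size=2, a 4-token chunk cap):

import torch

num_tokens_across_dp = torch.tensor([10, 6])
sp_size, max_num_tokens, chunk_idx = 2, 4, 0

sp_tokens = (num_tokens_across_dp + sp_size - 1) // sp_size
sp_tokens = sp_tokens.repeat_interleave(sp_size)      # tensor([5, 5, 3, 3])
local = [max(min(max_num_tokens, int(t) - max_num_tokens * chunk_idx), 1)
         for t in sp_tokens]                           # chunk 0 of each SP rank
assert local == [4, 4, 3, 3]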
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
@dataclass
class DPMetadata:
max_tokens_across_dp_cpu: torch.Tensor
cu_tokens_across_dp_cpu: torch.Tensor
num_tokens_across_dp_cpu: torch.Tensor
# NOTE: local_sizes should only be set by the chunked_sizes context manager
local_sizes: Optional[list[int]] = None
@staticmethod
@ -98,6 +113,17 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu()
# Get the cumulative tokens across sequence parallel ranks.
# In this case the input to the MoEs will be distributed w.r.t both
# DP and TP rank.
# When sp_size==1, this is just the cumulative num tokens across DP.
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
num_tokens_across_sp_cpu = (
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
num_tokens_across_sp_cpu = (
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@staticmethod
def should_ubatch_across_dp(
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
@ -147,10 +173,10 @@ class DPMetadata:
@staticmethod
def make(
parallel_config: ParallelConfig,
attn_metadata: Any,
num_tokens: int,
num_tokens_across_dp: Optional[torch.Tensor] = None
parallel_config: ParallelConfig,
attn_metadata: Any,
num_tokens: int,
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
) -> "DPMetadata":
assert parallel_config.data_parallel_size > 1
@ -167,18 +193,18 @@ class DPMetadata:
# If num_tokens_across_dp is None, it will be computed by all_reduce
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
if num_tokens_across_dp is None:
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
assert (num_tokens_across_dp_cpu is None
or num_tokens_across_dp_cpu[dp_rank] == batchsize
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
if num_tokens_across_dp_cpu is None:
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
batchsize, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
num_tokens_across_dp)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
@contextmanager
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
def chunked_sizes(self, sequence_parallel_size: int,
max_chunk_size_per_rank: int, chunk_idx: int):
"""
Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution.
@ -192,31 +218,40 @@ class DPMetadata:
`chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank.
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
to determine the chunk-wise split.
`self.local_sizes` is only valid inside the context.
Args:
sequence_parallel_size: When Attn is TP and MoE layers are EP,
we use SP between the layers to avoid
redundant ops. We need this value to
compute the chunked sizes.
max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for.
"""
cu_sizes = self.cu_tokens_across_dp_cpu
num_tokens_across_dp_cpu = [
(cu_sizes[i] -
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
for i in range(len(cu_sizes))
]
self.local_sizes = _compute_chunked_local_num_tokens(
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
self.num_tokens_across_dp_cpu, sequence_parallel_size,
max_chunk_size_per_rank, chunk_idx)
try:
yield self.local_sizes
finally:
self.local_sizes = None
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
"""
Context manager for setting self.local_sizes. Same as self.chunked_sizes
but without any chunking.
"""
self.local_sizes = _compute_sp_num_tokens(
self.num_tokens_across_dp_cpu, sequence_parallel_size)
try:
yield self.local_sizes
finally:
self.local_sizes = None
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
assert self.local_sizes is not None
return self.local_sizes

View File

@ -3,6 +3,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from contextlib import nullcontext
from enum import Enum
from typing import Callable, Literal, Optional, Union, get_args, overload
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
if dp_size is not None else get_dp_group().world_size)
self.is_sequence_parallel = is_sequence_parallel
if self.is_sequence_parallel:
self.sp_size = tp_size_
self.sp_size = tp_size_ if is_sequence_parallel else 1
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
with ctx.dp_metadata.chunked_sizes(self.sp_size,
moe_dp_chunk_size_per_rank,
chunk_idx):
process_chunk(chunk_start,
chunk_end,
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
else:
shared_output = None
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
ctx = get_forward_context()
sp_ctx = ctx.dp_metadata.sp_local_sizes(
self.sp_size) if ctx.dp_metadata else nullcontext()
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
with sp_ctx:
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel)
if shared_output is not None:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
final_hidden_states = (
shared_output,
final_hidden_states,
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states)
if shared_output is not None:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
final_hidden_states = (
shared_output,
final_hidden_states,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
states = self.maybe_all_reduce_tensor_model_parallel(states)
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states,
self.is_sequence_parallel)
return states
if (not self.is_sequence_parallel and self.reduce_results
and (self.tp_size > 1 or self.ep_size > 1)):
states = self.maybe_all_reduce_tensor_model_parallel(
states)
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
return states
if self.shared_experts is not None:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
@classmethod
def make_expert_params_mapping(

View File

@ -5,6 +5,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
@ -375,3 +376,20 @@ class PolyNorm(CustomOp):
x: torch.Tensor,
) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
class LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
self.eps).type_as(x)
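A short usage check for the new module; the import path matches the deepseek_v2 import change further down in this diff.

import torch
from vllm.model_executor.layers.layernorm import LayerNorm

norm = LayerNorm(dim=8)
x = torch.randn(2, 8, dtype=torch.bfloat16)
out = norm(x)                      # normalization runs in float32 internally
assert out.dtype == torch.bfloat16 and out.shape == x.shape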

View File

@ -24,6 +24,9 @@ class MLAModules:
q_a_layernorm: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
indexer: Optional[torch.nn.Module]
is_sparse: bool
topk_indices_buffer: Optional[torch.Tensor]
@CustomOp.register("multi_head_latent_attention")
@ -76,6 +79,13 @@ class MultiHeadLatentAttention(CustomOp):
self.kv_b_proj = mla_modules.kv_b_proj
self.rotary_emb = mla_modules.rotary_emb
self.o_proj = mla_modules.o_proj
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
if self.indexer is not None:
assert hasattr(self.indexer, "topk_tokens")
self.topk_tokens = self.indexer.topk_tokens
self.topk_indices_buffer = mla_modules.topk_indices_buffer
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
@ -92,6 +102,7 @@ class MultiHeadLatentAttention(CustomOp):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
@ -100,6 +111,7 @@ class MultiHeadLatentAttention(CustomOp):
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
kv_b_proj=self.kv_b_proj,
indexer=self.indexer,
)
self.prefix = prefix
@ -145,6 +157,10 @@ class MultiHeadLatentAttention(CustomOp):
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions,
self.rotary_emb)
attn_out = self.mla_attn(
q,
kv_c_normed,

View File

@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale
# at the same time.
if is_deep_gemm_e8m0_used():
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
layer.orig_dtype, layer.weight)
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
block_sz = tuple(layer.weight_block_size)
requant_weight_ue8m0_inplace(layer.weight.data,
layer.weight_scale.data, block_sz)
# SM90 Block FP8 CUTLASS requires row-major weight scales
elif (current_platform.is_device_capability(90)
and cutlass_block_fp8_supported
and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
layer.weight)):
and cutlass_block_fp8_supported and not should_use_deepgemm):
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data.T.contiguous(), requires_grad=False)

View File

@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
from vllm.config import QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -298,14 +298,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
Experts (MoE) Layer.
"""
def __init__(
self,
config: AriaTextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config, prefix)
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__(vllm_config, prefix)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.mlp = AriaTextMoELayer(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")

View File

@ -346,8 +346,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
use_mla=model_config.use_mla).page_size_bytes
dtype=kv_cache_dtype).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
@ -401,6 +400,31 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
"exactly equal.", mamba_padding_pct)
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Update the fp8 kv-cache to the custom "fp8_ds_mla" format for DeepSeekV3.2
"""
hf_config = vllm_config.model_config.hf_config
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
is_v32 = hasattr(hf_config, "index_topk")
assert is_v32
# For DeepSeekV3.2, we use a custom fp8 format as the default
# (i.e. when cache_dtype is "auto")
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto" or \
cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
@ -417,4 +441,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig,
"FalconMambaForCausalLM": MambaModelConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
}

View File

@ -53,8 +53,20 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
self.shared_head = SharedHead(config=config, quant_config=quant_config)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer)
def forward(
self,

View File

@ -32,17 +32,22 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.attention import Attention
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
@ -50,20 +55,35 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
logger = init_logger(__name__)
class DeepseekV2MLP(nn.Module):
@ -108,43 +128,6 @@ class DeepseekV2MLP(nn.Module):
return x
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
x = nn.functional.pad(x, (0, 0, 0, pad_len))
chunk = x.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(x, 0, start, chunk)
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk",
op_func=sequence_parallel_chunk,
fake_impl=sequence_parallel_chunk_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
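As a rough illustration of the padding math behind sequence_parallel_chunk and the all-gather that later undoes it, here is a self-contained sketch that stands in for the distributed collectives with plain slicing (toy shapes; chunk_for_rank is a made-up helper):

import torch

def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Pad the token dimension up to a multiple of tp_size, then take this
    # rank's contiguous slice -- the same bounds the custom op computes.
    pad = (-x.size(0)) % tp_size
    if pad:
        x = torch.nn.functional.pad(x, (0, 0, 0, pad))
    chunk = x.size(0) // tp_size
    return x[tp_rank * chunk:(tp_rank + 1) * chunk]

x = torch.randn(10, 4)  # 10 tokens, tp_size=4 -> padded to 12, 3 per rank
tp_size = 4
parts = [chunk_for_rank(x, r, tp_size) for r in range(tp_size)]
restored = torch.cat(parts, dim=0)[:x.size(0)]  # all_gather + truncate
assert torch.equal(restored, x)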
class DeepseekV2MoE(nn.Module):
def __init__(
@ -166,20 +149,7 @@ class DeepseekV2MoE(nn.Module):
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens result in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
in ("deepep_high_throughput",
"deepep_low_latency")
and parallel_config.enable_expert_parallel
and self.tp_size > 1)
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@ -278,8 +248,7 @@ class DeepseekV2MoE(nn.Module):
# TODO: We can replace the all_reduce at the end of attn with a
# reduce_scatter instead of chunking here.
if self.is_sequence_parallel:
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
hidden_states)
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
@ -328,6 +297,7 @@ class DeepseekV2Attention(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
@ -341,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
topk_indices_buffer: Optional[torch.Tensor] = None,
prefix: str = "",
) -> None:
super().__init__()
@ -358,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
assert topk_indices_buffer is None, "topk_indices_buffer is not \
supported for DeepseekV2Attention"
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
@ -470,6 +443,391 @@ class DeepseekV2Attention(nn.Module):
return output
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
cache_config: CacheConfig):
super().__init__()
self.kv_cache = [torch.tensor([])]
self.head_dim = head_dim
self.prefix = prefix
self.cache_config = cache_config
self.dtype = dtype
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def get_kv_cache_spec(self) -> KVCacheSpec:
return MLAAttentionSpec( # Only has one vector instead of K + V
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
)
def forward(self):
...
def get_attn_backend(self) -> AttentionBackend:
return DeepseekV32IndexerBackend
@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
kv_cache, # [num_blocks, block_size, head_dim + 1]
dst_value, # [cu_seq_lens[-1], head_dim]
dst_scale, # [cu_seq_lens[-1], 4]
block_table, # [batch_size, num_blocks]
cu_seq_lens, # [batch_size + 1, ]
batch_size,
):
num_blocks, block_size, _ = kv_cache.shape
head_dim = dst_value.shape[-1]
kv_cache = kv_cache.view(num_blocks, -1)
expected_value = []
expected_scale = []
for b in range(batch_size):
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
if s == 0:
continue
tot = cdiv(s, block_size)
blocks = block_table[b, :tot]
value = []
scale = []
full_block = torch.arange(tot - 1,
device=kv_cache.device,
dtype=torch.int32)
non_remaining_value = kv_cache[blocks[full_block], :block_size *
head_dim].view(-1, head_dim)
non_remaining_scale = kv_cache[blocks[full_block],
block_size * head_dim:].view(-1, 4)
remaining = s - (tot - 1) * block_size
value = torch.cat([
non_remaining_value,
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
],
dim=0)
scale = torch.cat([
non_remaining_scale,
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
remaining * 4].view(-1, 4)
],
dim=0)
expected_value.append(value)
expected_scale.append(scale)
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
gather_value = gather_value.view(torch.float8_e4m3fn)
gather_scale = gather_scale.view(torch.float32)
dst_value.copy_(gather_value)
dst_scale.copy_(gather_scale)
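A small self-contained check of the flattened block layout this gather assumes (fp8 values first, then one fp32 scale per token), using toy sizes instead of the real head_dim and block_size:

import torch

block_size, head_dim = 4, 8
values = torch.randn(block_size, head_dim).to(torch.float8_e4m3fn)
scales = torch.rand(block_size, 1)

# One cache block flattened to bytes: block_size * head_dim fp8 value
# bytes followed by block_size * 4 bytes of fp32 scales -- the same
# offsets the slicing above relies on.
block = torch.cat([values.view(torch.uint8).reshape(-1),
                   scales.view(torch.uint8).reshape(-1)])

v = block[:block_size * head_dim].view(torch.float8_e4m3fn).view(-1, head_dim)
s = block[block_size * head_dim:].view(torch.float32).view(-1, 1)
assert torch.equal(v.view(torch.uint8), values.view(torch.uint8))
assert torch.equal(s, scales)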
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
# careful: attn_metadata will be None during the dummy (profile) run
attn_metadata = get_forward_context().attn_metadata
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
k,
weights,
quant_block_size,
scale_fmt,
topk_tokens,
head_dim,
max_model_len,
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
num_prefills = attn_metadata.num_prefills
k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
device=k.device,
dtype=torch.float8_e4m3fn)
k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
device=k.device,
dtype=torch.float32)
cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
prefill_metadata.block_table,
prefill_metadata.cu_seq_lens,
num_prefills,
)
cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
num_tokens = attn_metadata.num_actual_tokens
logits = fp8_mqa_logits(
q_fp8[num_decode_tokens:num_tokens],
(k_fp8, k_scale),
weights[num_decode_tokens:num_tokens],
cu_seqlen_ks,
cu_seqlen_ke,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
topk_indices -= cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
mask = torch.full_like(topk_indices,
False,
dtype=torch.bool,
device=topk_indices.device)
mask = mask_lo & mask_hi
topk_indices = topk_indices.masked_fill(~mask, -1)
topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
if has_decode:
decode_metadata = attn_metadata.decode
# the kernel expects kv_cache of shape
# [num_blocks, block_size, n_head, head_dim]; we only have
# [num_blocks, block_size, head_dim], so add a singleton head dim
kv_cache = kv_cache.unsqueeze(-2)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in the edge case where a short chunked-prefill length is
# smaller than decode_threshold, since we split prefill and decode
# loosely by decode_threshold
# (currently set to 1 + number of speculative tokens)
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:])
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
positions = torch.arange(max_model_len,
device=current_device).unsqueeze(0).expand(
batch_size * next_n, -1)
row_indices = torch.arange(padded_num_tokens,
device=current_device) // next_n
next_n_offset = torch.arange(
padded_num_tokens,
device=padded_q_fp8_decode_tokens.device) % next_n
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
topk_indices = logits.topk(topk_tokens,
dim=-1)[1].to(torch.int32) # [B * N, K]
# drop top-k indices that fall out of range (those positions are
# already masked); this happens when the context length is
# shorter than K
topk_indices[topk_indices > index_end_pos] = -1
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens)
topk_indices_buffer[:num_decode_tokens, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
return topk_indices_buffer
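For intuition on the prefill masking above, a toy example that shifts global top-k positions into each query's [ks, ke) window and marks everything outside it with -1 (made-up tensors, no vLLM dependencies):

import torch

cu_seqlen_ks = torch.tensor([0, 0, 5])   # window start per query token
cu_seqlen_ke = torch.tensor([3, 5, 9])   # window end per query token
topk_indices = torch.tensor([[0, 2, 7],  # global positions picked by top-k
                             [4, 1, 8],
                             [5, 8, 2]])

local = topk_indices - cu_seqlen_ks[:, None]
in_window = (local >= 0) & (local < (cu_seqlen_ke - cu_seqlen_ks)[:, None])
local = local.masked_fill(~in_window, -1)
print(local)
# tensor([[ 0,  2, -1],
#         [ 4,  1, -1],
#         [ 0,  3, -1]])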
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
# profile run
# NOTE(Chen): create the largest possible flattened_kv so that
# profile_run measures the correct memory usage.
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
device=k.device,
dtype=torch.uint8)
_k_fp8 = _flattened_kv[..., :head_dim].view(
torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
direct_register_custom_op(
op_name="sparse_attn_indexer",
op_func=sparse_attn_indexer,
mutates_args=["topk_indices_buffer"],
fake_impl=sparse_attn_indexer_fake,
dispatch_key=current_platform.dispatch_key,
)
class Indexer(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
q_lora_rank: int,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
topk_indices_buffer: Optional[torch.Tensor],
prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
self.config = config
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64
self.head_dim = config.index_head_dim # 128
self.rope_dim = config.qk_rope_head_dim # 64
self.q_lora_rank = q_lora_rank # 1536
# no tensor parallel, just replicated
self.wq_b = ReplicatedLinear(self.q_lora_rank,
self.head_dim * self.n_head,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b")
self.wk = ReplicatedLinear(hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk")
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(hidden_size,
self.n_head,
quant_config=None,
prefix=f"{prefix}.weights_proj")
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"
self.quant_block_size = 128 # TODO: get from config
self.topk_indices_buffer = topk_indices_buffer
# NOTE(zyongye): we use a naive fp8 cache that stores values in fp8
# and one fp32 scale per quant_block_size elements
self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim +
self.head_dim // self.quant_block_size * 4,
dtype=torch.uint8,
prefix=f"{prefix}.k_cache",
cache_config=cache_config)
self.max_model_len = vllm_config.model_config.max_model_len
self.prefix = prefix
from vllm.v1.attention.backends.mla.indexer import (
get_max_prefill_buffer_size)
self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
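The head_dim handed to DeepseekV32IndexerCache above is effectively a per-token byte count rather than a feature size; with the values noted in the comments (index_head_dim 128, quant block size 128) it works out as:

head_dim = 128          # index_head_dim, stored as fp8 (1 byte per element)
quant_block_size = 128  # one fp32 scale (4 bytes) per quant block
bytes_per_token = head_dim + head_dim // quant_block_size * 4
print(bytes_per_token)  # 132 = 128 value bytes + 4 scale bytes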
def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
rotary_emb) -> torch.Tensor:
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
k, _ = self.wk(hidden_states)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
q = torch.cat([q_pe, q_nope], dim=-1)
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt
is not None)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = weights.unsqueeze(
-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
@ -481,6 +839,7 @@ class DeepseekV2MLAAttention(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
@ -495,6 +854,7 @@ class DeepseekV2MLAAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
topk_indices_buffer: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -575,6 +935,15 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
self.indexer = Indexer(vllm_config, config, hidden_size,
q_lora_rank, quant_config, cache_config,
topk_indices_buffer, f"{prefix}.indexer")
else:
self.indexer = None
mla_modules = MLAModules(
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
@ -588,7 +957,11 @@ class DeepseekV2MLAAttention(nn.Module):
if self.q_lora_rank is not None else None,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else None,
indexer=self.indexer,
is_sparse=self.is_v32,
topk_indices_buffer=topk_indices_buffer,
)
self.mla_attn = MultiHeadLatentAttention(
self.hidden_size,
self.num_local_heads,
@ -614,7 +987,10 @@ class DeepseekV2MLAAttention(nn.Module):
class DeepseekV2DecoderLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
@ -637,6 +1013,7 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -652,6 +1029,7 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
topk_indices_buffer=topk_indices_buffer,
)
if (config.n_routed_experts is not None
@ -735,6 +1113,16 @@ class DeepseekV2Model(nn.Module):
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
@ -747,7 +1135,8 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:

View File

@ -29,10 +29,9 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
vllm_config: VllmConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.mtp_emb_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
prefix)
self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
def forward(
self,
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict({
str(idx):
ErnieMultiTokenPredictorLayer(
config,
vllm_config,
f"{prefix}.layers.{idx}",
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)

View File

@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
class Glm4DecoderLayer(nn.Module):
def __init__(
self,
config: Glm4Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[Glm4Config] = None) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)

View File

@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
vllm_config: VllmConfig,
layer_idx: int,
quant_config: QuantizationConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
prefix=f"{prefix}.experts",
apply_router_weight_on_input=False,
has_bias=True,
activation="swigluoai")
activation="swigluoai",
is_sequence_parallel=self.is_sequence_parallel)
def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
if self.is_sequence_parallel:
x = sequence_parallel_chunk(x)
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
x = x[:num_tokens]
return x
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
def __init__(
self,
config: GptOssConfig,
cache_config: CacheConfig,
quant_config: QuantizationConfig,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
self.layer_idx = extract_layer_index(prefix)
self.attn = OAIAttention(config,
prefix=f"{prefix}.attn",
cache_config=cache_config)
self.mlp = MLPBlock(config,
self.mlp = MLPBlock(vllm_config,
self.layer_idx,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
):
super().__init__()
self.config = vllm_config.model_config.hf_config
self.cache_config = vllm_config.cache_config
self.quant_config = vllm_config.quant_config
self.parallel_config = vllm_config.parallel_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
self.config.num_hidden_layers,
lambda prefix: TransformerBlock(
self.config,
cache_config=self.cache_config,
quant_config=self.quant_config,
vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",

View File

@ -29,12 +29,13 @@ from typing import Any, Optional
import torch
from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
is_sequence_parallel=False,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
self.is_sequence_parallel = is_sequence_parallel
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts")
prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
num_tokens = orig_shape[0]
final_hidden_states = final_hidden_states[:num_tokens]
return final_hidden_states.view(orig_shape)
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
def __init__(
self,
config: GraniteMoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GraniteMoeDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
disable_tp: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
disable_tp=disable_tp,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=disable_tp,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[LlamaConfig] = None) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:

View File

@ -28,7 +28,8 @@ from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
router_scores = torch.sigmoid(router_scores.float())
return (router_scores, router_indices.to(torch.int32))
def __init__(self,
config: Llama4TextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False,
disable_tp=self.is_sequence_parallel,
)
self.experts = SharedFusedMoE(
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
is_sequence_parallel=self.is_sequence_parallel,
)
def forward(self, hidden_states):
num_tokens = hidden_states.shape[0]
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
router_logits, _ = self.router(hidden_states)
shared_out, routed_out = self.experts(
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
)
experts_out = routed_out + shared_out
if self.tp_size > 1:
if self.is_sequence_parallel:
experts_out = tensor_model_parallel_all_gather(experts_out, 0)
experts_out = experts_out[:num_tokens]
elif self.tp_size > 1:
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out)
@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):
class Llama4DecoderLayer(nn.Module):
def __init__(
self,
config: Llama4TextConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[Llama4TextConfig] = None) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
self.layer_idx + 1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
config=config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.feed_forward",
)
else:

View File

@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
Llama4DecoderLayer(
self.config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
config=self.config,
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
vllm_config: VllmConfig,
disable_input_layernorm: bool,
prefix: str = "",
config: Optional[LlamaConfig] = None,
) -> None:
super().__init__(config, prefix=prefix)
super().__init__(vllm_config, prefix=prefix, config=config)
# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
self.config,
vllm_config,
i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
config=self.config,
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -9,13 +9,11 @@ import torch.nn as nn
from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -29,17 +27,14 @@ logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix)
def __init__(self,
vllm_config: VllmConfig,
prefix: str = "",
config: Optional[LlamaConfig] = None) -> None:
super().__init__(vllm_config, prefix=prefix, config=config)
config = config or vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
# override qkv
self.self_attn.qkv_proj = QKVParallelLinear(
@ -127,9 +122,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
config=self.config,
cache_config=current_vllm_config.cache_config,
current_vllm_config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
config=self.config,
)
])
if hasattr(self.config, "target_hidden_size"):

View File

@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: FlashConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@ -329,6 +330,7 @@ class FlashDecoderLayer(nn.Module):
# Dual attention structure
self.self_attn = nn.ModuleList([
DeepseekV2MLAAttention(
vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -454,6 +456,7 @@ class FlashModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: FlashDecoderLayer(
vllm_config,
config,
cache_config=cache_config,
quant_config=quant_config,

View File

@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Detect attention implementation.
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.attn_backend = attn_backend
self.use_upstream_fa = use_upstream_fa
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen2_5_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList([
Qwen2_5_VisionBlock(dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(
vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(depth)
Qwen2_5_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
])
self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
@property
def dtype(self) -> torch.dtype:

View File

@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__)
# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_FRAMES_PER_VIDEO = 14
# === Vision Inputs === #
@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=1,
image_processor=image_processor,
)
return num_image_tokens
@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
num_frames=1,
image_processor=None,
)
return max_image_size
@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_processor=None,
)
def _get_max_video_frames(self, max_tokens: int) -> int:
def _get_max_video_frames(self,
max_tokens: int,
start_num_frames: int = 1) -> int:
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
num_frames = start_num_frames
while True:
next_num_frames = num_frames + 1
@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
self,
seq_len: int,
mm_counts: Mapping[str, int],
max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
) -> int:
max_videos = mm_counts.get("video", 0)
max_total_frames = self._get_max_video_frames(seq_len)
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
_MAX_FRAMES_PER_VIDEO)
max_frames_per_video)
return max(max_frames_per_video, 1)
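For intuition on the frame budgeting above, a toy version of the loop with a fixed, made-up per-frame token cost (the real method derives the per-frame cost via _get_vision_info):

def max_video_frames(max_tokens: int, tokens_per_frame: int,
                     start_num_frames: int = 1) -> int:
    # Grow the frame count while one more frame still fits the budget,
    # mirroring _get_max_video_frames with a constant per-frame cost.
    num_frames = start_num_frames
    while (num_frames + 1) * tokens_per_frame <= max_tokens:
        num_frames += 1
    return num_frames

print(max_video_frames(10000, 720))  # 13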

View File

@ -29,13 +29,13 @@ from typing import Any, Optional, Union
import torch
from torch import nn
from transformers import Qwen3MoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3MoeConfig,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__()
config = vllm_config.model_config.hf_text_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
assert hidden_states.dim(
) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
is_input_1d = hidden_states.dim() == 1
hidden_dim = hidden_states.shape[-1]
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
# return to 1d if input is 1d
return final_hidden_states.squeeze(0) if is_input_1d else \
final_hidden_states
@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):
class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3MoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb)
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@ -361,11 +372,9 @@ class Qwen3MoeModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.get_text_config()
cache_config = vllm_config.cache_config
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
prefix=f"{prefix}.embed_tokens")
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen3MoeDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
enable_eplb=enable_eplb),
lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -580,7 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
config = vllm_config.model_config.hf_text_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config

View File

@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
vllm_config: VllmConfig,
layer_type: str,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.config = config
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock(
config=config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = Qwen3NextMLP(
@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config
speculative_config = vllm_config.speculative_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
def get_layer(prefix: str):
return Qwen3NextDecoderLayer(
config,
vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=prefix,
enable_eplb=enable_eplb,
)
self.start_layer, self.end_layer, self.layers = make_layers(

View File

@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
super().__init__()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
config,
vllm_config,
layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{idx}',
) for idx in range(self.num_mtp_layers))

View File

@ -33,11 +33,14 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
smart_resize as image_smart_resize)
from transformers.models.qwen3_vl import (Qwen3VLProcessor,
Qwen3VLVideoProcessor)
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
Qwen3VLConfig, Qwen3VLVisionConfig)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata
from vllm.attention.layer import check_upstream_fa_availability
@ -84,6 +87,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__)
# Official recommended max pixels is 24576 * 32 * 32
_MAX_FRAMES_PER_VIDEO = 24576
class Qwen3_VisionPatchEmbed(nn.Module):
@ -158,6 +164,8 @@ class Qwen3_VisionBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
@ -170,7 +178,9 @@ class Qwen3_VisionBlock(nn.Module):
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel)
use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen3_VisionMLP(dim,
mlp_hidden_dim,
act_fn=act_fn,
@ -287,19 +297,6 @@ class Qwen3_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(vision_config.depth)
])
self.merger = Qwen3_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
@ -325,10 +322,34 @@ class Qwen3_VisionTransformer(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen3-VL does not support {self.attn_backend} backend now.")
self.blocks = nn.ModuleList([
Qwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa)
for layer_idx in range(vision_config.depth)
])
@property
def dtype(self) -> torch.dtype:
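
As a reading aid, the following is a minimal stand-in for the backend-resolution order added above, with the `_Backend` members replaced by plain strings and the upstream-flash-attention probe passed in as a boolean; both are simplifying assumptions, not the vLLM API.

SUPPORTED_VIT_BACKENDS = {"FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "ROCM_AITER_FA"}

def resolve_vit_backend(initial: str, upstream_fa_available: bool) -> tuple[str, bool]:
    backend, use_upstream_fa = initial, False
    # Upgrade to flash attention when the upstream package can provide it.
    if backend != "FLASH_ATTN" and upstream_fa_available:
        backend, use_upstream_fa = "FLASH_ATTN", True
    if backend not in SUPPORTED_VIT_BACKENDS:
        raise RuntimeError(f"Qwen3-VL does not support {backend} backend now.")
    # The resolved pair is what gets forwarded to every Qwen3_VisionBlock.
    return backend, use_upstream_fa

print(resolve_vit_backend("TORCH_SDPA", upstream_fa_available=True))   # ('FLASH_ATTN', True)
print(resolve_vit_backend("TORCH_SDPA", upstream_fa_available=False))  # ('TORCH_SDPA', False)
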
@ -569,11 +590,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_height: int,
num_frames: int = 2,
do_resize: bool = True,
image_processor: Optional[Qwen2VLImageProcessorFast],
image_processor: Optional[Union[Qwen2VLImageProcessorFast,
Qwen3VLVideoProcessor]],
) -> tuple[ImageSize, int]:
if image_processor is None:
if image_processor is None and num_frames > 1:
image_processor = self.get_video_processor()
elif image_processor is None:
image_processor = self.get_image_processor()
is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
@ -581,12 +607,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
temporal_patch_size = vision_config.temporal_patch_size
if do_resize:
if is_video:
smart_resize = video_smart_resize
extra_kwargs = {
"num_frames": num_frames,
"temporal_factor": temporal_patch_size
}
else:
smart_resize = image_smart_resize
extra_kwargs = {}
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"],
**extra_kwargs,
)
preprocessed_size = ImageSize(width=resized_width,
height=resized_height)
@ -605,6 +641,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
return preprocessed_size, num_vision_tokens
def _get_max_video_frames(self,
max_tokens: int,
start_num_frames: int = 2) -> int:
return super()._get_max_video_frames(max_tokens,
start_num_frames=start_num_frames)
def get_num_frames_with_most_features(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
return super().get_num_frames_with_most_features(
seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
def get_max_video_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(
seq_len, mm_counts),
image_processor=None,
)
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
video_fps: float, merge_size: int):
if not isinstance(indices, list):
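
The NOTE in get_max_video_tokens above encodes a small token-budget calculation; a hedged sketch of that arithmetic follows, treated as an upper-bound estimate with the 9.5-token average taken directly from the comment and the soft-token count below chosen only for illustration.

TIMESTAMP_TOKENS = 9.5      # "<{timestamp} seconds>" averages ~9.5 text tokens
WRAPPER_TOKENS = 3          # vision_start_token + video_token + vision_end_token
PER_SOFT_TOKEN = TIMESTAMP_TOKENS + WRAPPER_TOKENS   # 12.5, the multiplier used above

video_soft_tokens = 1024                             # illustrative count
print(int(video_soft_tokens * PER_SOFT_TOKEN))       # 12800 formatted tokens budgeted
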
@ -674,6 +743,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
self.info.get_image_size_with_most_features())
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts)
target_video_size, _ = self.info._get_vision_info(
image_width=target_width,
image_height=target_height,
num_frames=target_num_frames,
image_processor=self.info.get_video_processor(),
)
return {
"image":
self._get_dummy_images(width=target_width,
@ -681,8 +756,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
width=target_video_size.width,
height=target_video_size.height,
num_frames=target_num_frames,
num_videos=num_videos,
),
@ -1051,14 +1126,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
@ -1074,11 +1152,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
@ -1513,7 +1595,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
skip_prefixes = []
if self.visual is None:
skip_prefixes.extend(["visual."])
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_mm_mapping(self) -> MultiModelKeys:
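
A hedged usage sketch of the text-only path gated above: setting both multimodal limits to zero is what leaves `self.visual` as None, so the vision tower is never built and its `visual.`-prefixed weights are skipped at load time. The checkpoint name below is illustrative only.

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-VL-30B-A3B-Instruct",        # illustrative checkpoint
    limit_mm_per_prompt={"image": 0, "video": 0},  # disables the vision encoder entirely
)
print(llm.generate("Describe tensor parallelism in one sentence."))
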

View File

@ -212,6 +212,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
if is_fused_expert:
loaded_weight = loaded_weight.transpose(-1,
-2) # no bias
@ -230,8 +232,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
name_mapped, params_dict, loaded_weight,
shard_id, num_experts)
else:
if is_pp_missing_parameter(name_mapped, self):
continue
# Skip loading extra parameters for GPTQ/modelopt models
if name_mapped.endswith(
ignore_suffixes
@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
if not multimodal_config.get_limit_per_prompt("image") and \
not multimodal_config.get_limit_per_prompt("video"):
self.visual = None
else:
self.visual = Qwen3_VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
)
self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
prefix=maybe_prefix(
@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
config.vision_config.deepstack_visual_indexes
) if self.use_deepstack else 0
# register buffer for deepstack
self.deepstack_input_embeds = [
torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
] if self.use_deepstack else None
if self.use_deepstack and self.visual is not None:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
config.text_config.hidden_size)
for _ in range(self.deepstack_num_level)
]
else:
self.deepstack_input_embeds = None
self.visual_dim = config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level

View File

@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = {
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),

View File

@ -13,11 +13,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available)
logger = init_logger(__name__)
@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
return hf_config.hidden_size
text_config = hf_config.get_text_config()
return text_config.hidden_size
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.sequence_parallel_chunk_impl(x)
def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
y = nn.functional.pad(x, (0, 0, 0, pad_len))
else:
y = x
chunk = y.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(y, 0, start, chunk)
def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk_impl",
op_func=sequence_parallel_chunk_impl,
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
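
To make the padding-plus-narrow behaviour above concrete, here is a self-contained, single-process sketch in which `tp_size` and `tp_rank` are plain arguments rather than values read from the tensor-parallel group.

import torch
import torch.nn.functional as F

def chunk_for_rank(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad the token axis so it divides evenly across tensor-parallel ranks.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return x.narrow(0, tp_rank * chunk, chunk)

x = torch.randn(5, 8)                                    # 5 tokens, hidden size 8
parts = [chunk_for_rank(x, tp_size=4, tp_rank=r) for r in range(4)]
assert all(p.shape == (2, 8) for p in parts)             # 5 tokens padded to 8, 2 per rank
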

View File

@ -50,6 +50,7 @@ class MediaConnector:
connection: HTTPConnection = global_http_connection,
*,
allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
) -> None:
"""
Args:
@ -82,6 +83,9 @@ class MediaConnector:
allowed_local_media_path_ = None
self.allowed_local_media_path = allowed_local_media_path_
if allowed_media_domains is None:
allowed_media_domains = []
self.allowed_media_domains = allowed_media_domains
def _load_data_url(
self,
@ -115,6 +119,14 @@ class MediaConnector:
return media_io.load_file(filepath)
def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
if self.allowed_media_domains and url_spec.hostname not in \
self.allowed_media_domains:
raise ValueError(
f"The URL must be from one of the allowed domains: "
f"{self.allowed_media_domains}. Input URL domain: "
f"{url_spec.hostname}")
def load_from_url(
self,
url: str,
@ -125,6 +137,8 @@ class MediaConnector:
url_spec = urlparse(url)
if url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection
data = connection.get_bytes(url, timeout=fetch_timeout)
@ -150,6 +164,8 @@ class MediaConnector:
loop = asyncio.get_running_loop()
if url_spec.scheme.startswith("http"):
self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
future = loop.run_in_executor(global_thread_pool,
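
A minimal sketch of the new allow-list check, using only the standard library; the hostnames below are illustrative, not taken from the patch.

from urllib.parse import urlparse

def assert_allowed_domain(url: str, allowed_media_domains: list[str]) -> None:
    hostname = urlparse(url).hostname
    if allowed_media_domains and hostname not in allowed_media_domains:
        raise ValueError(
            f"The URL must be from one of the allowed domains: "
            f"{allowed_media_domains}. Input URL domain: {hostname}")

assert_allowed_domain("https://example.com/cat.png", ["example.com"])  # passes
assert_allowed_domain("https://example.com/cat.png", [])               # empty list: no restriction
# assert_allowed_domain("https://other.test/cat.png", ["example.com"]) # would raise ValueError
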

View File

@ -93,11 +93,14 @@ class CpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
has_sink: bool, use_sparse: bool) -> str:
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on CPU.")
logger.info("Using Torch SDPA backend.")
if not use_v1:
raise ValueError("CPU backend only supports V1.")

View File

@ -129,6 +129,8 @@ class CudaPlatformBase(Platform):
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
if model_config is not None and model_config.use_mla:
use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
# If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
@ -175,6 +177,12 @@ class CudaPlatformBase(Platform):
"Forcing kv cache block size to 64 for FlashInferMLA "
"backend.")
# TODO(Chen): remove this hacky code
if use_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse "
"backend.")
# lazy import to avoid circular import
from vllm.config import CUDAGraphMode
@ -205,6 +213,12 @@ class CudaPlatformBase(Platform):
@classmethod
def get_vit_attn_backend(cls, head_size: int,
dtype: torch.dtype) -> _Backend:
# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
if cls.has_device_capability(100):
return _Backend.TORCH_SDPA
if dtype not in (torch.float16, torch.bfloat16):
return _Backend.XFORMERS
@ -225,7 +239,7 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
has_sink, use_sparse) -> str:
if use_mla:
if not use_v1:
raise RuntimeError(
@ -235,6 +249,11 @@ class CudaPlatformBase(Platform):
from vllm.attention.ops.flashmla import is_flashmla_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
if use_sparse:
logger.info_once("Using Sparse MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla.flashmla_sparse."
"FlashMLASparseBackend")
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None and cls.is_device_capability(100)
and block_size == 128)
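
Read together, the CUDA changes in this file amount to a small decision rule. Below is a hedged sketch with the backend string copied from the hunks above and everything else simplified away; the stand-in config class is hypothetical.

def pick_mla_path(hf_config, block_size: int) -> tuple[str, int]:
    # DeepSeek-V3.2-style sparse attention is detected via the `index_topk`
    # attribute on the HF config.
    use_sparse = hasattr(hf_config, "index_topk")
    if use_sparse:
        # FlashMLASparse currently requires 64-token KV-cache blocks.
        return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                "FlashMLASparseBackend", 64)
    return ("<dense MLA backend chosen by the existing logic>", block_size)

class _Cfg:          # stand-in for a DeepSeek-V3.2 hf_config
    index_topk = 2048

print(pick_mla_path(_Cfg(), block_size=128))   # sparse backend, block size forced to 64
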

View File

@ -194,7 +194,7 @@ class Platform:
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
has_sink: bool, use_sparse: bool) -> str:
"""Get the attention backend class of a device."""
return ""

View File

@ -195,7 +195,10 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
has_sink) -> str:
has_sink, use_sparse) -> str:
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on ROCm.")
if use_mla:
if not use_v1:
raise RuntimeError(

View File

@ -49,7 +49,10 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink) -> str:
has_sink, use_sparse) -> str:
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on TPU.")
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)

View File

@ -36,7 +36,10 @@ class XPUPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool, use_mla: bool,
has_sink: bool) -> str:
has_sink: bool, use_sparse) -> str:
if use_sparse:
raise NotImplementedError(
"Sparse Attention is not supported on XPU.")
use_v1 = envs.VLLM_USE_V1
if not use_v1:
raise ValueError("XPU backend only supports V1.")

View File

@ -66,6 +66,8 @@ class LazyConfigDict(dict):
_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
chatglm="ChatGLMConfig",
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v3="DeepseekV3Config",
deepseek_v32="DeepseekV3Config",
kimi_vl="KimiVLConfig",
Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)

View File

@ -8,6 +8,7 @@ Model configs may be defined in this directory for the following reasons:
"""
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.dotsocr import DotsOCRConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig
@ -37,6 +38,7 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
"ChatGLMConfig",
"DeepseekVLV2Config",
"DeepseekV3Config",
"DotsOCRConfig",
"EAGLEConfig",
"RWConfig",

View File

@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DeepseekV3Config(PretrainedConfig):
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=129280,
hidden_size=7168,
intermediate_size=18432,
moe_intermediate_size=2048,
num_hidden_layers=61,
num_nextn_predict_layers=1,
num_attention_heads=128,
num_key_value_heads=128,
n_shared_experts=1,
n_routed_experts=256,
ep_size=1,
routed_scaling_factor=2.5,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method='noaux_tc',
n_group=8,
topk_group=4,
num_experts_per_tok=8,
moe_layer_freq=1,
first_k_dense_replace=3,
norm_topk_prob=True,
scoring_func='sigmoid',
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
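
A quick usage sketch of the vendored config class; the values are the defaults defined above, and the sums are plain arithmetic over those defaults.

from vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config

cfg = DeepseekV3Config(num_hidden_layers=2)         # any field can be overridden via kwargs
print(cfg.model_type)                               # "deepseek_v3"
print(cfg.kv_lora_rank + cfg.qk_rope_head_dim)      # 512 + 64 = 576
print(cfg.qk_nope_head_dim + cfg.qk_rope_head_dim)  # 128 + 64 = 192
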

View File

@ -130,6 +130,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8,
}
TORCH_DTYPE_TO_NUMPY_DTYPE = {
@ -3433,6 +3434,12 @@ def has_triton_kernels() -> bool:
return _has_module("triton_kernels")
def has_tilelang() -> bool:
"""Whether the optional `tilelang` package is available."""
return _has_module("tilelang")
def set_process_title(name: str,
suffix: str = "",
prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None:

Some files were not shown because too many files have changed in this diff.