Compare commits

...

33 Commits

Author SHA1 Message Date
f71952c1c4 [Build/CI] Revert back to Ubuntu 20.04, install python 3.12 with uv (#26103)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 22:22:31 -07:00
d1007767c5 [Bugfix] Disable cascade attention with FlashInfer (#26130)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 22:22:22 -07:00
c75c2e70d6 [Deepseek v3.2] Support indexer prefill chunking (#25999)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
9d9a2b77f1 [Small] Prevent bypassing media domain restriction via HTTP redirects (#26035)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
6040e0b6c0 [BugFix] Fix FI accuracy issue when used for MLA prefill (#26063)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
05bf0c52a1 Update base image to 22.04 (jammy) (#26065)
Signed-off-by: Huy Do <huydhn@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
c536881a7c [BugFix] ChunkedLocalAttention is currently not CG compatible (#26034)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
ebce361c07 [BugFix][DP/EP] Fix CUTLASS MLA hang under load (#26026)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:50 -07:00
e4beabd2c8 [BugFix] Fix default kv-cache-dtype default for DeepseekV3.2 (#25988)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:42 -07:00
febb688356 [Bugfix] Fix __syncwarp on ROCM (#25996)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:42 -07:00
a1825fe645 [MM] Add text-only mode for Qwen3-VL (#26000)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:42 -07:00
bab9231bf1 [Model] MTP fallback to eager for DeepSeek v32 (#25982)
Signed-off-by: Lu Fang <fanglu@fb.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:47:38 -07:00
c214d699fd [spec decode] Consolidate speculative decode method name for MTP (#25232)
Signed-off-by: zixi-qi <qizixi@meta.com>
2025-09-30 22:47:11 -07:00
c3dfb0f6dd [Bench] Add DeepSeekV32 to MoE benchmark (#25962)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
83f3c9beae [bugfix][deepseek] fix flashmla kernel selection (#25956)
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
d0b178cef1 [NIXL] Add support for MLA caches with different latent dim (#25902)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
b3230e1ac0 [New Model] DeepSeek-V3.2 (Rebased to Main) (#25896)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com>
Co-authored-by: Lucia Fang <fanglu@meta.com>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Siyuan Fu <siyuanf@nvidia.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Xiaozhu Meng <mxz297@gmail.com>
Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:24 -07:00
03df0fb5d2 [BugFix] Fix DP/EP hang (#25906)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:36:10 -07:00
9471879bd4 [Bug] Fix Weight Loading for Block FP8 Cutlass SM90 (#25909)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:32:47 -07:00
ab5b6459df [Bugfix] Fallback ViT attn backend to SDPA for blackwell (#25851)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-30 22:32:47 -07:00
8ce5d3198d [P/D] NIXL Updates (#25844)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Sage Moore <sage@neuralmagic.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-28 22:55:33 -07:00
09c2cbc04a [Bugfix] fix Qwen3VLMoe load when pp > 1 (#25838)
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-28 22:55:17 -07:00
4c347044c9 [VLM] Update Qwen3-VL max_num_video_tokens calculation for configurable video profiling (#25557)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:12 -07:00
19e7ab7315 [Bugfix] Fix Qwen3-VL regression from #24982 (#25814)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
6de3d431d9 [MM] Optimize memory profiling for scattered multimodal embeddings (#25810)
Signed-off-by: Roger Wang <hey@rogerw.io>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
b14773bd64 [Bugfix][NIXL] Fix Async Scheduler timeout issue (#25808)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:11 -07:00
26a7a33b88 [Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:35:03 -07:00
5aa5811a16 [CI] Fix FlashInfer AOT in release docker image (#25730)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
c2fa2d4dc9 [Bugfix] Allow Only SDPA Backend for ViT on B200 for Qwen3-VL (#25788)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
32335c8b34 Add option to restrict media domains (#25783)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
04c2b26972 Add filtering for chat template kwargs (#25794)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
ee10d7e6ff Validate API tokens in constant time (#25781)
Signed-off-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: rentianyue-jk <rentianyue-jk@360shuke.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
bb79c4da2f Reduce the Cuda Graph memory footprint when running with DBO (#25779)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00
127 changed files with 5591 additions and 897 deletions

View File

@@ -76,7 +76,7 @@ steps:
queue: arm64_cpu_queue_postmerge
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
# Add job to create multi-arch manifest

View File

@ -584,8 +584,9 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
elif config.architectures[0] in ( elif config.architectures[0] in (
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"DeepseekV32ForCausalLM",
"Glm4MoeForCausalLM", "Glm4MoeForCausalLM",
): ):
E = config.n_routed_experts E = config.n_routed_experts

View File

@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
else()
FetchContent_Declare(
flashmla
- GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
+ GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
- GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
+ GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
- cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
- if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
+ set(SUPPORT_ARCHS)
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
+ list(APPEND SUPPORT_ARCHS 9.0a)
+ endif()
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
+ list(APPEND SUPPORT_ARCHS 10.0a)
+ endif()
+ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
+ if(FLASH_MLA_ARCHS)
+ set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
+ list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
set(FlashMLA_SOURCES
- ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
- ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
- ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
- ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
- ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
+ ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
+ ${flashmla_SOURCE_DIR}/csrc/pybind.cpp
+ ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
+ ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
+ ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
+ )
+ set(FlashMLA_Extension_SOURCES
+ ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
+ ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
+ ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
+ )
set(FlashMLA_INCLUDES
+ ${flashmla_SOURCE_DIR}/csrc
+ ${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/cutlass/include
- ${flashmla_SOURCE_DIR}/csrc)
+ ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+ )
+ set(FlashMLA_Extension_INCLUDES
+ ${flashmla_SOURCE_DIR}/csrc
+ ${flashmla_SOURCE_DIR}/csrc/sm90
+ ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
+ ${flashmla_SOURCE_DIR}/csrc/cutlass/include
+ ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+ )
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
+ set_gencode_flags_for_srcs(
+ SRCS "${FlashMLA_Extension_SOURCES}"
+ CUDA_ARCHS "${FLASH_MLA_ARCHS}")
define_gpu_extension_target(
_flashmla_C
DESTINATION vllm
@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
USE_SABI 3
WITH_SOABI)
+ # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+ # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+ target_compile_options(_flashmla_C PRIVATE
+ $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
+ $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
+ define_gpu_extension_target(
+ _flashmla_extension_C
+ DESTINATION vllm
+ LANGUAGE ${VLLM_GPU_LANG}
+ SOURCES ${FlashMLA_Extension_SOURCES}
+ COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
+ ARCHITECTURES ${VLLM_GPU_ARCHES}
+ INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
+ USE_SABI 3
+ WITH_SOABI)
+ # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+ # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+ target_compile_options(_flashmla_extension_C PRIVATE
+ $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
+ $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
else()
- # Create an empty target for setup.py when not targeting sm90a systems
+ # Create empty targets for setup.py when not targeting sm90a systems
add_custom_target(_flashmla_C)
+ add_custom_target(_flashmla_extension_C)
endif()

View File

@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
load_page_table(
blk_coord,
problem_shape,
params.mainloop,
shared_storage.tensors,
pipeline_page_table, pipeline_pt_producer_state,
local_split_kv
);
}
}
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
load_cpasync(
blk_coord,
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
params.mainloop_params,
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv,
/* must be shared pipe */
pipeline_page_table, pipeline_pt_consumer_state
);
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
load_tma</* paged= */ true>(
blk_coord,
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv
);
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
}
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
load_tma<false>(
blk_coord,
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv
);
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
}
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
mma(blk_coord,
problem_shape,
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_producer_state,
pipeline_p_mma, pipeline_p_mma_consumer_state,
pipeline_mma_o, pipeline_mma_o_producer_state,
local_split_kv
);
}
}
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto split_kv = params.split_kv;
auto local_split_kv = split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
compute(
blk_coord,
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv
);
}
@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
cutlass::arch::NamedBarrier(
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
kNamedBarrierEpilogue
- ).arrive();
+ ).arrive_and_wait();
return;
}

View File

@@ -56,3 +56,11 @@ void cp_gather_cache(
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+ // Indexer K quantization and cache function
+ void indexer_k_quant_and_cache(
+ torch::Tensor& k, // [num_tokens, head_dim]
+ torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
+ torch::Tensor& slot_mapping, // [num_tokens]
+ int64_t quant_block_size, // quantization block size
+ const std::string& scale_fmt);

View File

@@ -16,6 +16,7 @@
#include <algorithm>
#include <cassert>
+ #include <cfloat> // FLT_MIN
#include <map>
#include <vector>
@@ -396,6 +397,180 @@ __global__ void concat_and_cache_mla_kernel(
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_ds_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int entry_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int64_t dst_idx_start =
block_idx * block_stride + block_offset * entry_stride;
// Create 4 tile scales in shared memory
__shared__ float smem[20];
float* shard_abs_max = smem;
float* tile_scales = smem + 16;
// For the NoPE part, each tile of 128 elements is handled by 4 warps
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
// The first thread of the first warp in each tile writes the scale
// value for the tile. The RoPE part (last 64 elements) is handled
// by another 2 warps (64 threads).
// So in total, we use 18 warps (576 threads) per block.
// Cast kv_cache to 16_bit for RoPE values
scalar_t* kv_cache_16bit =
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
// The last 64 threads handle the RoPE part
if (threadIdx.x >= kv_lora_rank) {
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
// RoPE values start after the packed 8-bit NoPE values and the
// 32-bit scales
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
kv_cache_16bit[dst_idx] = k_pe[src_idx];
return;
}
// Determine the scale for each chunk of NoPE
const int16_t tile_idx = threadIdx.x >> 7;
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
const int16_t lane_idx = threadIdx.x & 31;
// Load the NoPE element for this thread into registers
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
const scalar_t src_val = kv_c[src_idx];
// Warp-level reduction to find the max absolute value in the warp
float max_abs = fabsf(src_val);
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2) {
#ifdef USE_ROCM
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
#else
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
#endif
}
// The first lane of each warp in each tile writes the max_abs of this part
// of the tile to shared memory
if (lane_idx == 0) {
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
}
__syncthreads();
// The first lane of the first warp in each tile computes the scale for the
// tile and writes it to shared memory and to kv_cache
if (warp_idx == 0 && lane_idx == 0) {
float4 shard_abs_max_vec =
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
448.f;
// Avoid division by zero in `scaled_convert`
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
}
__syncthreads();
// Now all threads in the block scale and write their element
const float scale_val = tile_scales[tile_idx];
const int64_t dst_idx = dst_idx_start + threadIdx.x;
kv_cache[dst_idx] =
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
src_val, scale_val);
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int head_dim, // dimension of each head
const int quant_block_size, // quantization block size
const int cache_block_size, // cache block size
const int cache_stride, // stride for each token in kv_cache
const bool use_ue8m0 // use ue8m0 scale format
) {
constexpr int VEC_SIZE = 4;
const int64_t token_idx = blockIdx.x;
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
threadIdx.y * blockDim.x + threadIdx.x) *
VEC_SIZE;
const int64_t slot_idx = slot_mapping[token_idx];
const int64_t block_idx = slot_idx / cache_block_size;
const int64_t block_offset = slot_idx % cache_block_size;
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
return;
}
float2 k_val = (reinterpret_cast<const float2*>(
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
float amax = 0.0f;
for (int i = 0; i < VEC_SIZE; i++) {
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
}
#ifndef USE_ROCM
__syncwarp();
#endif
// Reduced amax
for (int mask = 16; mask > 0; mask /= 2) {
#ifdef USE_ROCM
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
}
#ifndef USE_ROCM
__syncwarp();
#endif
float scale = fmaxf(amax, 1e-4) / 448.0f;
if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale)));
}
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
block_offset * head_dim + head_dim_idx;
for (int i = 0; i < VEC_SIZE; i++) {
kv_cache[dst_offset + i] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
}
if (threadIdx.x == 0) {
const int64_t dst_scale_idx =
block_idx * cache_block_size * cache_stride +
cache_block_size * head_dim +
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
}
}
} // namespace vllm
// KV_T is the data type of key and value tensors.
@@ -438,7 +613,7 @@ void reshape_and_cache(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
- CALL_RESHAPE_AND_CACHE)
+ CALL_RESHAPE_AND_CACHE);
}
// KV_T is the data type of key and value tensors.
@@ -509,6 +684,18 @@ void reshape_and_cache_flash(
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
@@ -531,20 +718,44 @@ void concat_and_cache_mla(
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
- TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+ if (kv_cache_dtype == "fp8_ds_mla") {
+ TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
+ TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
+ TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
+ "kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
+ TORCH_CHECK(kv_c.itemsize() == 2,
+ "kv_c.itemsize() must be 2 for fp8_ds_mla");
+ TORCH_CHECK(k_pe.itemsize() == 2,
+ "k_pe.itemsize() must be 2 for fp8_ds_mla");
+ } else {
+ TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+ }
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
int entry_stride = kv_cache.stride(1);
- dim3 grid(num_tokens);
- dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
- CALL_CONCAT_AND_CACHE_MLA);
+ if (kv_cache_dtype == "fp8_ds_mla") {
+ dim3 grid(num_tokens);
+ // For the NoPE part, each tile of 128 elements is handled by 4 warps
+ // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
+ // The first thread of the first warp in each tile writes the scale
+ // value for the tile. The RoPE part (last 64 elements) is handled
+ // by another 2 warps (64 threads).
+ // So in total, we use 18 warps (576 threads) per block.
+ dim3 block(576);
+ DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+ CALL_CONCAT_AND_CACHE_DS_MLA);
+ } else {
+ dim3 grid(num_tokens);
+ dim3 block(std::min(kv_lora_rank, 512));
+ DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+ CALL_CONCAT_AND_CACHE_MLA);
+ }
}
namespace vllm {
@@ -922,3 +1133,42 @@ void cp_gather_cache(
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
}
}
// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(k.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
cache_block_size, cache_stride, use_ue8m0);
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt) {
int num_tokens = k.size(0);
int head_dim = k.size(1);
int cache_block_size = kv_cache.size(1);
int cache_stride = kv_cache.size(2);
bool use_ue8m0 = scale_fmt == "ue8m0";
TORCH_CHECK(k.device() == kv_cache.device(),
"k and kv_cache must be on the same device");
TORCH_CHECK(k.device() == slot_mapping.device(),
"k and slot_mapping must be on the same device");
TORCH_CHECK(head_dim % quant_block_size == 0,
"head_dim must be divisible by quant_block_size");
constexpr int vec_size = 4;
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
(quant_block_size * vec_size));
dim3 block(32, vec_size);
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
CALL_INDEXER_K_QUANT_AND_CACHE);
}
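For readers following the checks above, here is a hedged back-of-envelope check of the fp8_ds_mla cache-entry layout in Python (assuming 16-bit kv_c/k_pe inputs); the variable names are illustrative, but the numbers mirror the TORCH_CHECKs and the kernel's index math.

# Illustrative layout arithmetic for one fp8_ds_mla cache entry (not vLLM code).
kv_lora_rank = 512      # NoPE values, stored as 1-byte fp8
num_tiles = 4           # one float32 scale per 128-element tile
pe_dim = 64             # RoPE values, kept in the 2-byte input dtype

entry_bytes = kv_lora_rank * 1 + num_tiles * 4 + pe_dim * 2
assert entry_bytes == 656  # matches TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize())

# Offsets used by the kernel, expressed in elements of the reinterpreted views:
assert kv_lora_rank // 4 == 128       # float32 view: tile scales start at element 128
assert kv_lora_rank // 2 + 8 == 264   # 16-bit view: RoPE values start at element 264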

View File

@@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
TORCH_CHECK(false, \
"Unsupported input type of kv cache: ", SRC_DTYPE); \
} \
+ } else if (KV_DTYPE == "fp8_ds_mla") { \
+ if (SRC_DTYPE == at::ScalarType::Float) { \
+ FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
+ } else if (SRC_DTYPE == at::ScalarType::Half) { \
+ FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
+ } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
+ FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
+ } else { \
+ TORCH_CHECK(false, \
+ "Unsupported input type of kv cache: ", SRC_DTYPE); \
+ } \
} else { \
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
} \

View File

@@ -713,6 +713,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
+ cache_ops.def(
+ "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
+ "slot_mapping, "
+ "int quant_block_size, str kv_cache_dtype) -> ()");
+ cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
+ &indexer_k_quant_and_cache);
}
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
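For orientation, a hedged Python sketch of calling the op registered above through torch.ops, following the schema string in the binding. The tensor shapes, the cache-stride arithmetic, and the vllm._C import path are assumptions made for illustration, not code taken from this diff.

import torch
import vllm._C  # assumed import that loads the compiled extension and registers _C_cache_ops

num_tokens, head_dim, quant_block_size = 8, 128, 128
num_blocks, block_size = 4, 64
# Assumed per-token layout: head_dim fp8 bytes followed by one float32 scale per quant block.
cache_stride = head_dim + (head_dim // quant_block_size) * 4

k = torch.randn(num_tokens, head_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, cache_stride, dtype=torch.uint8, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

torch.ops._C_cache_ops.indexer_k_quant_and_cache(
    k, kv_cache, slot_mapping, quant_block_size, "ue8m0")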

View File

@@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12
#
# Example:
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
+ # Important: We build with an old version of Ubuntu to maintain broad
+ # compatibility with other Linux OSes. The main reason for this is that the
+ # glibc version is baked into the distro, and binaries built with one glibc
+ # version are not backwards compatible with OSes that use an earlier version.
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
# TODO: Restore to base image after FlashInfer AOT wheel fixed
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
@@ -75,34 +80,19 @@ ARG TARGETPLATFORM
ARG INSTALL_KV_CONNECTORS=false
ENV DEBIAN_FRONTEND=noninteractive
- ARG DEADSNAKES_MIRROR_URL
- ARG DEADSNAKES_GPGKEY_URL
ARG GET_PIP_URL
- # Install Python and other dependencies
+ # Install system dependencies and uv, then create Python virtual environment
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
- && apt-get install -y ccache software-properties-common git curl sudo \
- && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
- if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
- mkdir -p -m 0755 /etc/apt/keyrings ; \
- curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
- sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
- echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
- fi ; \
- else \
- for i in 1 2 3; do \
- add-apt-repository -y ppa:deadsnakes/ppa && break || \
- { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
- done ; \
- fi \
- && apt-get update -y \
- && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
- && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
- && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
- && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
- && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
+ && apt-get install -y ccache software-properties-common git curl sudo python3-pip \
+ && curl -LsSf https://astral.sh/uv/install.sh | sh \
+ && $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
+ && rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
+ && ln -s /opt/venv/bin/python3 /usr/bin/python3 \
+ && ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
+ && ln -s /opt/venv/bin/pip /usr/bin/pip \
&& python3 --version && python3 -m pip --version
ARG PIP_INDEX_URL UV_INDEX_URL
@@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
- # Install uv for faster pip installs
- RUN --mount=type=cache,target=/root/.cache/uv \
- python3 -m pip install uv
+ # Activate virtual environment and add uv to PATH
+ ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
+ ENV VIRTUAL_ENV="/opt/venv"
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
@@ -142,7 +132,7 @@ WORKDIR /workspace
COPY requirements/common.txt requirements/common.txt
COPY requirements/cuda.txt requirements/cuda.txt
RUN --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system -r requirements/cuda.txt \
+ uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# cuda arch list used by torch
@@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system -r requirements/build.txt \
+ uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
COPY . .
@@ -269,7 +259,7 @@ COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \
- uv pip install --system -r requirements/dev.txt \
+ uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### DEV IMAGE ####################
@@ -404,6 +394,9 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
+ export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
+ # HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
+ uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. -f1)
# Build AOT kernels
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
python3 -m flashinfer.aot

View File

@@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
#
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
- FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
+ FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM

View File

@@ -6,6 +6,13 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
+ !!! tip
+ When serving multi-modal models, consider setting `--allowed-media-domains` to restrict the domains that vLLM can access, to prevent it from reaching arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
+ Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.
+ This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
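To make the tip above concrete, here is a hedged Python sketch of the kind of allow-list check the flag implies; the helper and its logic are illustrative only and are not taken from vLLM's implementation.

from urllib.parse import urlparse

# Hypothetical allow-list check, mirroring the example domains in the tip above.
ALLOWED_MEDIA_DOMAINS = ["upload.wikimedia.org", "github.com", "www.bogotobogo.com"]

def is_allowed_media_url(url: str) -> bool:
    host = urlparse(url).hostname or ""
    return host in ALLOWED_MEDIA_DOMAINS

# A URL on an allowed host passes; an internal endpoint does not. Note that a
# redirect issued by an allowed host could still point elsewhere, which is why
# the tip also suggests VLLM_MEDIA_URL_ALLOW_REDIRECTS=0.
assert is_allowed_media_url("https://upload.wikimedia.org/wikipedia/commons/example.png")
assert not is_allowed_media_url("http://169.254.169.254/latest/meta-data/")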
## Offline Inference
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:

View File

@@ -60,6 +60,15 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components
+ ### 4. **Restrict Domain Access for Media URLs:**
+ Restrict the domains that vLLM can access for media URLs by setting
+ `--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
+ (e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
+ Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
+ redirects from being followed to bypass domain restrictions.
## Security and Firewalls: Protecting Exposed vLLM Systems
While vLLM is designed to allow unsafe network services to be isolated to

View File

@@ -54,6 +54,7 @@ def parse_args():
"--method",
type=str,
default="eagle",
+ choices=["ngram", "eagle", "eagle3", "mtp"],
)
parser.add_argument("--num-spec-tokens", type=int, default=2)
parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
- elif args.method.endswith("mtp"):
+ elif args.method == "mtp":
speculative_config = {
- "method": args.method,
+ "method": "mtp",
"num_speculative_tokens": args.num_spec_tokens,
}
else:
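As a usage note, a hedged sketch of passing the consolidated method name offline, assuming the example script forwards its speculative_config dict to vllm.LLM unchanged; the model path is only a placeholder for an MTP-capable checkpoint and is not taken from this diff.

from vllm import LLM

# "mtp" is now the single method name for MTP-style speculation.
llm = LLM(
    model="path/to/mtp-capable-model",  # placeholder
    speculative_config={
        "method": "mtp",
        "num_speculative_tokens": 2,
    },
)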

View File

@@ -322,6 +322,8 @@ class precompiled_wheel_utils:
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
+ "vllm/_flashmla_extension_C.abi3.so",
+ "vllm/_sparse_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/cumem_allocator.abi3.so",
@@ -589,6 +591,8 @@ if _is_cuda():
# not targeting a hopper system
ext_modules.append(
CMakeExtension(name="vllm._flashmla_C", optional=True))
+ ext_modules.append(
+ CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
if _build_custom_ops():

View File

@@ -191,7 +191,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_dtype,
- use_mla=False,
),
layer_names=[self.attn.layer_name],
vllm_config=self.vllm_config,

View File

@@ -45,6 +45,7 @@ class MockModelConfig:
logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
+ allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
skip_tokenizer_init: bool = False

View File

@@ -240,6 +240,7 @@ class MockModelConfig:
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
+ allowed_media_domains: Optional[list[str]] = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)

View File

@@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format,
+ resolve_chat_template_kwargs,
resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
+ QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
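As a rough mental model for what the test above expects: resolve_chat_template_kwargs should forward only the kwargs that the chat template (or the tokenizer's standard apply_chat_template parameters) can actually consume, and silently drop the rest. A minimal sketch of the template-variable half of that idea, using jinja2 directly and a hypothetical helper name (not vLLM's implementation):

import jinja2
from jinja2 import meta

def filter_by_template_vars(chat_template: str, kwargs: dict) -> dict:
    # Keep only kwargs whose names appear as undeclared variables in the
    # Jinja template. The real resolver in vLLM additionally admits the
    # tokenizer's own apply_chat_template() arguments such as tools and
    # continue_final_message, which is why they show up in expected_kwargs.
    env = jinja2.Environment()
    template_vars = meta.find_undeclared_variables(env.parse(chat_template))
    return {k: v for k, v in kwargs.items() if k in template_vars}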
# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
# yapf: disable

View File

@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
torch.testing.assert_close(kv_cache, ref_kv_cache)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_concat_and_cache_ds_mla(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: str,
) -> None:
if dtype.itemsize != 2:
pytest.skip("ds_mla only supports 16-bit input")
kv_cache_dtype = "fp8_ds_mla"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
kv_cache_dtype=kv_cache_dtype,
device=device)
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
tile_data = torch.zeros(128, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_cache_slice = ref_cache[block_idx, block_offset]
ref_cache_16bit = ref_cache_slice.view(dtype)
ref_cache_32bit = ref_cache_slice.view(torch.float32)
kv_c_data = kv_c[i]
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = (tile_idx + 1) * 128
tile_data[:] = kv_c_data[tile_start:tile_end]
# tile_scale = tile_data.amax().to(torch.float32) / 448.
# NOTE: Using torch's amax() gives different results,
# so this must be manually computed.
tile_data_float = tile_data.to(torch.float32)
manual_max = abs(tile_data_float[0])
for j in range(1, 128):
manual_max = max(manual_max, abs(tile_data_float[j]))
tile_scale = manual_max / 448.
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
tile_data,
tile_scale.item(),
kv_dtype="fp8")
for j in range(qk_rope_head_dim):
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
kv_cache_slice = kv_cache[block_idx, block_offset]
ref_cache_slice = ref_cache[block_idx, block_offset]
kv_nope = kv_cache_slice[:kv_lora_rank]
ref_nope = ref_cache_slice[:kv_lora_rank]
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
ref_scales = ref_cache_slice.view(
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
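The reference loop above implicitly encodes the per-token fp8_ds_mla cache entry layout. Spelled out as byte offsets, this is a sketch derived from the test itself (which assumes four 128-wide tiles over the latent), not an authoritative spec of the kernel:

def fp8_ds_mla_entry_layout(kv_lora_rank: int, qk_rope_head_dim: int) -> dict:
    # One fp8 byte per latent element, then four float32 tile scales,
    # then the rope values kept in the 16-bit input dtype.
    latent_bytes = kv_lora_rank
    scale_bytes = 4 * 4
    rope_bytes = 2 * qk_rope_head_dim
    return {
        "latent": (0, latent_bytes),
        "scales": (latent_bytes, latent_bytes + scale_bytes),
        "rope": (latent_bytes + scale_bytes,
                 latent_bytes + scale_bytes + rope_bytes),
        "entry_size": latent_bytes + scale_bytes + rope_bytes,
    }

# For kv_lora_rank=512 and qk_rope_head_dim=64 this yields a 656-byte entry:
# 512 fp8 latent bytes, 16 bytes of scales, and 128 bytes of rope values.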
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)

View File

@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import pytest
import torch
from vllm.platforms import current_platform
from vllm.utils import cdiv, has_deep_gemm
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
fp8_paged_mqa_logits, get_num_sms,
get_paged_mqa_logits_metadata)
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
# x: (num_blocks, block_size, 1, head_dim)
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
x_fp8 = torch.empty(
(num_blocks, block_size * (head_dim + 4)),
device=x.device,
dtype=torch.uint8,
)
x_fp8[:, :block_size * head_dim] = x_scaled.view(
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
x_fp8[:,
block_size * head_dim:] = sf.view(num_blocks,
block_size).view(dtype=torch.uint8)
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
def per_custom_dims_cast_to_fp8(
x: torch.Tensor, dims: tuple,
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
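Both casting helpers above divide the per-group amax by 448 before converting to torch.float8_e4m3fn; 448 is the largest finite value of that dtype, so the scaled values span the full fp8 range. A quick sanity check, assuming a PyTorch build with fp8 support:

import torch
assert torch.finfo(torch.float8_e4m3fn).max == 448.0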
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
chunk_size = seq_len // 2
cp_size = seq_len_kv // seq_len
cp_id = cp_size // 3
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
for i in range(chunk_size):
ke[i] = cp_id * chunk_size + i
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
return ks, ke
def _ref_fp8_mqa_logits(
q: torch.Tensor,
kv: torch.Tensor,
weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor,
cu_seqlen_ke: torch.Tensor,
):
seq_len_kv = kv.shape[0]
k = kv
q = q.float()
k = k.float()
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
>= cu_seqlen_ks[:, None])
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
< cu_seqlen_ke[:, None])
mask = mask_lo & mask_hi
score = torch.einsum("mhd,and->hmn", q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float("-inf"))
return logits
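In the reference above, query row i may attend to KV positions j with cu_seqlen_ks[i] <= j < cu_seqlen_ke[i]; the context-parallel helper simply builds two such windows per request. A small CPU-only illustration of the mask with hypothetical values:

import torch
ks = torch.tensor([0, 0, 1])
ke = torch.tensor([2, 3, 3])
j = torch.arange(4)[None, :]
mask = (j >= ks[:, None]) & (j < ke[:, None])
# mask[i, j] is True exactly when query i may attend to kv position j:
# [[ True,  True, False, False],
#  [ True,  True,  True, False],
#  [False,  True,  True, False]]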
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
def test_deepgemm_fp8_mqa_logits():
torch.manual_seed(0)
random.seed(0)
num_heads, head_dim = 32, 128
for seq_len in (512, ):
for seq_len_kv in (1024, ):
for disable_cp in (False, True):
q = torch.randn(
seq_len,
num_heads,
head_dim,
device="cuda",
dtype=torch.bfloat16,
)
kv = torch.randn(seq_len_kv,
head_dim,
device="cuda",
dtype=torch.bfloat16)
weights = torch.randn(seq_len,
num_heads,
device="cuda",
dtype=torch.float32)
if disable_cp:
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
ke = torch.arange(seq_len, dtype=torch.int,
device="cuda") + (seq_len_kv - seq_len)
else:
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
ref_logits = _ref_fp8_mqa_logits(
q=q,
kv=kv,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke,
)
ref_neginf_mask = ref_logits == float("-inf")
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
def _ref_fp8_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
max_model_len: int,
):
batch_size, next_n, _, _ = q.size()
_, block_size, _, _ = kv_cache.size()
logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device=q.device,
dtype=torch.float32,
)
context_lens_list = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens_list[i]
q_offsets = torch.arange(context_len - next_n,
context_len,
device="cuda")
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
0, 1).contiguous())
for block_rk in range(cdiv(context_len, block_size)):
block_idx = block_tables[i][block_rk]
qx, kx = q[i], kv_cache[block_idx]
k_offsets = torch.arange(
block_rk * block_size,
(block_rk + 1) * block_size,
device="cuda",
)
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
<= q_offsets[:, None])
s = torch.where(
mask[None, :, :],
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
logits.dtype),
float("-inf"),
)
s = torch.relu(s) * weight_slice[..., None]
s = s.sum(dim=0)
logits[
i * next_n:(i + 1) * next_n,
block_rk * block_size:(block_rk + 1) * block_size,
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
float("-inf"))
return logits
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="SM90 and SM100 only")
def test_deepgemm_fp8_paged_mqa_logits():
torch.manual_seed(0)
random.seed(0)
max_model_len = 4096
for batch_size, next_n in [(4, 1), (2, 2)]:
for heads, index_dim in [(32, 128)]:
for avg_kv in (2048, ):
num_blocks, blocksize = max_model_len * 2, 64
q = torch.randn(
(batch_size, next_n, heads, index_dim),
device="cuda",
dtype=torch.bfloat16,
)
kv_cache = torch.randn(
(num_blocks, blocksize, 1, index_dim),
device="cuda",
dtype=torch.bfloat16,
)
weights = torch.randn(
(batch_size * next_n, heads),
device="cuda",
dtype=torch.float32,
)
context_lens = (torch.randint(int(0.8 * avg_kv),
int(1.2 * avg_kv),
(batch_size, )).cuda().to(
torch.int32))
max_block_len = ((context_lens.max().item() + blocksize - 1) //
blocksize * blocksize)
block_tables = torch.zeros(
(batch_size, max_block_len),
device="cuda",
dtype=torch.int32,
)
counter = 0
block_idx_pool = list(range(num_blocks))
random.shuffle(block_idx_pool)
for i in range(batch_size):
ctx_len = int(context_lens[i].item())
for j in range((ctx_len + blocksize - 1) // blocksize):
block_tables[i][j] = block_idx_pool[counter]
counter += 1
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
schedule_metadata = get_paged_mqa_logits_metadata(
context_lens, blocksize, get_num_sms())
logits = fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
context_lens,
block_tables,
schedule_metadata,
max_model_len,
)
ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
context_lens,
block_tables,
max_model_len,
)
positions = (torch.arange(max_model_len,
device="cuda").unsqueeze(0).expand(
batch_size * next_n, -1))
row_indices = (
torch.arange(batch_size * next_n, device="cuda") // next_n)
next_n_offset = (
torch.arange(batch_size * next_n, device="cuda") % next_n)
mask = positions <= (context_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
logits = logits.masked_fill(~mask, 0)
ref_logits = ref_logits.masked_fill(~mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"

View File

@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
descale_k = None
def flash_mla():
-return flash_mla_with_kvcache(
-q,
-blocked_k,
-block_table,
-cache_seqlens,
-dv,
-tile_scheduler_metadata,
-num_splits,
-causal=causal,
-descale_q=descale_q,
-descale_k=descale_k,
-)
+return flash_mla_with_kvcache(q,
+blocked_k,
+block_table,
+cache_seqlens,
+dv,
+tile_scheduler_metadata,
+num_splits,
+causal=causal,
+descale_q=descale_q,
+descale_k=descale_k)
def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()

View File

@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
def _cuda_sm90_available() -> bool:
if not torch.cuda.is_available():
return False
major, _ = torch.cuda.get_device_capability()
return major == 9
def test_sparse_flashmla_metadata_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 128
num_heads_k = 1
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
topk = 128
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
assert tile_md.dtype == torch.int32
assert num_splits.dtype == torch.int32
def test_sparse_flashmla_decode_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 1
head_dim_k = 576
head_dim_v = 512
num_heads_k = 1
page_block_size = 64
bytes_per_token = 656
topk = 128
# Metadata
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
# q_heads_per_hk = num_heads_q // num_heads_k
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
q_seq_per_hk,
num_heads_k,
num_heads_q=num_heads_q,
topk=topk,
is_fp8_kvcache=True)
# Inputs
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
dtype=torch.bfloat16,
device=device)
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
dtype=torch.uint8,
device=device)
indices = torch.zeros((batch_size, seqlen_q, topk),
dtype=torch.int32,
device=device)
block_table = torch.zeros((batch_size, 128),
dtype=torch.int32,
device=device)
out, lse = fm.flash_mla_with_kvcache(q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_md,
num_splits,
indices=indices,
is_fp8_kvcache=True)
assert out.shape[0] == batch_size
assert out.shape[-1] == head_dim_v
assert lse.shape[0] == batch_size
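A note on the constants in the decode smoke test above: head_dim_k = 576 and bytes_per_token = 656 are consistent with the fp8 MLA cache entry used elsewhere in these tests (a 512-dim latent plus 64 rope dims, stored as 512 fp8 bytes, four float32 scales, and 64 bf16 rope values):

kv_lora_rank, qk_rope_head_dim = 512, 64
assert kv_lora_rank + qk_rope_head_dim == 576                 # head_dim_k
assert kv_lora_rank + 4 * 4 + 2 * qk_rope_head_dim == 656     # bytes_per_token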
def test_sparse_flashmla_prefill_smoke():
import vllm.attention.ops.flashmla as fm
ok, reason = fm.is_flashmla_supported()
if not ok or not _cuda_sm90_available():
pytest.skip(reason or "SM90 not available")
device = torch.device("cuda")
s_q = 1
s_kv = 1
h_q = 64 # kernel expects multiple of 64
h_kv = 1
d_qk = 576
d_v = 512
topk = 128
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
d_v)
assert out.shape == (s_q, h_q, d_v)
assert max_logits.shape == (s_q, h_q)
assert lse.shape == (s_q, h_q)

View File

@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.testing import assert_close
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
def test_pack_seq_basic_fp8():
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test cases with 3D tensors (N, H, D)
test_cases = [
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
]
for N, H, D, B, lengths_list in test_cases:
# Create input tensor with small values for fp8
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor(lengths_list, device=device)
# Pack the data
packed = pack_seq_triton(x, lengths)
# Check output shape and properties
expected_shape = (B, max(lengths_list), H, D)
assert packed.shape == expected_shape
assert packed.dtype == dtype
assert packed.device == x.device
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = sum(lengths_list[:b])
seq_len = lengths_list[b]
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
actual_data = packed[b, :seq_len].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
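For orientation, the packing semantics checked above can be reproduced with plain torch ops. This is a slow reference sketch (regular float dtypes, zero padding), not the Triton kernel itself:

import torch

def pack_seq_ref(x: torch.Tensor, lengths: torch.Tensor,
                 pad_value: float = 0.0) -> torch.Tensor:
    # x: (N, H, D) concatenated sequences; lengths: (B,) with sum == N.
    B, max_len = len(lengths), int(lengths.max())
    out = torch.full((B, max_len, *x.shape[1:]), pad_value,
                     dtype=x.dtype, device=x.device)
    start = 0
    for b, n in enumerate(lengths.tolist()):
        out[b, :n] = x[start:start + n]
        start += n
    return out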
def test_pack_seq_custom_padding_fp8():
"""Test pack_seq_triton with custom padding values for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 20, 8, 16, 2
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
# Test with different padding values
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
result = pack_seq_triton(x, lengths, pad_value=pad_value)
# Check valid data
for b in range(B):
start_idx = b * 10
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
actual_data = result[b, :10].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
# Check padding (fp8 has limited range, so check for large values)
padded_data = result[:, 10:].to(torch.float32)
if pad_value < 0:
assert torch.all(padded_data < -50) # Large negative values
elif pad_value > 0:
assert torch.all(padded_data > 50) # Large positive values
else:
assert torch.allclose(padded_data,
torch.zeros_like(padded_data),
atol=1e-2)
def test_pack_seq_default_negative_inf_padding_fp8():
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# B = 2
N, H, D = 20, 8, 16
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
result = pack_seq_triton(x, lengths)
# Check that padding is large negative values (fp8 representation of -inf)
padded_data = result[:, 10:].to(torch.float32)
assert torch.all(
padded_data < -100) # fp8 -inf is represented as large negative number
def test_pack_seq_edge_cases_fp8():
"""Test pack_seq_triton with edge cases for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test with single batch element
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([10], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (1, 10, 8, 16)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([1, 1, 1], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (3, 1, 4, 8)
# Test with different sequence lengths
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([5, 7, 3], device=device)
result = pack_seq_triton(x, lengths)
assert result.shape == (3, 7, 8, 16)
def test_pack_seq_different_block_sizes_fp8():
"""Test pack_seq_triton with different block sizes for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 100, 16, 32, 4
lengths = torch.tensor([25, 25, 25, 25], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
# Test different block sizes
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
assert result.shape == (B, 25, H, D)
# Check that valid data is preserved (within fp8 precision)
for b in range(B):
start_idx = b * 25
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
actual_data = result[b, :25].to(torch.float32)
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
def test_pack_seq_shape_consistency():
"""Test that pack_seq_triton maintains shape consistency."""
device = "cuda"
dtype = torch.float8_e4m3fn
N, H, D, B = 20, 8, 16, 2
lengths = torch.tensor([10, 10], device=device)
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
result = pack_seq_triton(x, lengths)
# Check shape consistency
assert result.shape[0] == B # Batch dimension
assert result.shape[1] == lengths.max().item() # Max sequence length
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
def test_pack_unpack_roundtrip_fp8():
"""Test that pack -> unpack gives us back the original data for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test cases with 3D tensors
test_cases = [
(6, 8, 4, 2, [3, 3]),
(10, 4, 8, 3, [2, 4, 4]),
(20, 16, 32, 4, [5, 5, 5, 5]),
(15, 8, 16, 3, [7, 5, 3]),
]
for N, H, D, B, lengths_list in test_cases:
# Create input tensor with small values for fp8
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor(lengths_list, device=device)
# Pack the data
packed = pack_seq_triton(x, lengths)
# Unpack the data
unpacked = unpack_seq_triton(packed, lengths)
# Check that we get back the original data (within fp8 precision)
assert unpacked.shape == x.shape
x_f32 = x.to(torch.float32)
unpacked_f32 = unpacked.to(torch.float32)
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
# Unpack without explicit start locations (computed in kernel)
unpacked_with_loc = unpack_seq_triton(packed, lengths)
assert_close(x_f32,
unpacked_with_loc.to(torch.float32),
rtol=1e-3,
atol=1e-2)
def test_unpack_seq_triton_edge_cases_fp8():
"""Test unpack function with edge cases for fp8."""
device = "cuda"
dtype = torch.float8_e4m3fn
# Test with single batch element
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([10], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
# Test with very short sequences
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([1, 1, 1], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
# Only compare the first 3 elements that were actually packed
assert_close(x[:3].to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
x = x.to(dtype=dtype)
lengths = torch.tensor([5, 7, 3], device=device)
packed = pack_seq_triton(x, lengths)
unpacked = unpack_seq_triton(packed, lengths)
assert unpacked.shape == x.shape
assert_close(x.to(torch.float32),
unpacked.to(torch.float32),
rtol=1e-1,
atol=1e-2)

View File

@ -207,6 +207,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
trust_remote_code=True),
+"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
min_transformers_version="4.54"),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",

View File

@ -8,7 +8,8 @@ import pytest
from vllm import LLM
from vllm.utils import GiB_bytes
-from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
+from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
+get_kv_cache_configs)
from vllm.v1.engine.core import EngineCore as V1EngineCore
from ..utils import create_new_process_for_each_test
@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# Avoid calling model.forward()
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
-scheduler_kv_cache_config = get_kv_cache_configs(
+kv_cache_configs = get_kv_cache_configs(
vllm_config,
kv_cache_specs,
[10 * GiB_bytes],
-)[0]
+)
+scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
+kv_cache_configs)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config

View File

@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes()) @pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image], async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str): raw_image_url: str, suffix: str):
connector = MediaConnector() connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
url_image = url_images[raw_image_url] url_image = url_images[raw_image_url]
try: try:
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
modality_idxs = argsort_mm_positions(mm_positions)
assert modality_idxs == expected_modality_idxs
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}},
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
with pytest.raises(ValueError):
_, _ = connector.fetch_video(disallowed_url)
with pytest.raises(ValueError):
_, _ = await connector.fetch_video_async(disallowed_url)

View File

@ -26,5 +26,5 @@ class DummyPlatform(Platform):
def get_attn_backend_cls(self, backend_name, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla,
-has_sink):
+has_sink, use_sparse):
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
+from typing import Optional, Union
import pytest
import torch
@ -10,6 +11,7 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
+from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import FullAttentionSpec
@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache(
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
-randomize_blocks: bool = True) -> torch.Tensor:
+randomize_blocks: bool = True,
+kv_cache_dtype: Optional[str] = None,
+scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
"""Create and prepopulate an MLA KV cache with context data.
Args:
@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: Common attention metadata
randomize_blocks: Whether to randomly permute blocks
or use sequential order
+kv_cache_dtype: Optional kv cache dtype string. When set to
+"fp8_ds_mla" the cache is populated using the
+fp8 DeepSeek MLA layout via concat_and_cache_mla.
+scale: Scaling factor forwarded to concat_and_cache_mla when the
+fp8 cache layout is requested.
Returns:
MLA KV cache tensor
@ -105,23 +114,61 @@
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
-# Create MLA KV cache: (num_blocks, block_size, head_size)
-kv_cache = torch.empty(num_blocks,
-block_size,
-head_size,
-dtype=dtype,
-device=device)
-kv_cache_flat = kv_cache.view(-1, head_size)
+use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
+if use_fp8_ds_mla:
+if not kv_c_contexts:
+raise ValueError("kv_c_contexts cannot be empty when using"
+" fp8_ds_mla cache dtype")
+kv_lora_rank = kv_c_contexts[0].shape[-1]
+rope_dim = k_pe_contexts[0].shape[-1]
+entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
+kv_cache = torch.zeros(num_blocks,
+block_size,
+entry_size,
+dtype=torch.uint8,
+device=device)
+scale_tensor = (scale
+if isinstance(scale, torch.Tensor) else torch.tensor(
+scale, dtype=torch.float32, device=device))
+scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
+else:
+# Create MLA KV cache: (num_blocks, block_size, head_size)
+kv_cache = torch.empty(num_blocks,
+block_size,
+head_size,
+dtype=dtype,
+device=device)
+kv_cache_flat = kv_cache.view(-1, head_size)
# Populate the cache with the context tokens
# Start from block_id=1 since block_id=0 is considered the null block
start_block_idx = 1
for i in range(batch_size):
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
-kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
+context_len = kv_c_context.shape[0]
+if context_len == 0:
+start_block_idx += cdiv(int(seq_lens[i]), block_size)
+continue
start = start_block_idx * block_size
-end = start + kv_context.shape[0]
-kv_cache_flat[start:end, ...] = kv_context
+if use_fp8_ds_mla:
+slots = torch.arange(context_len, device=device,
+dtype=torch.long) + start
+ops.concat_and_cache_mla(
+kv_c_context,
+k_pe_context.squeeze(1),
+kv_cache,
+slots,
+kv_cache_dtype="fp8_ds_mla",
+scale=scale_tensor,
+)
+else:
+kv_context = torch.cat(
+[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
+end = start + kv_context.shape[0]
+kv_cache_flat[start:end, ...] = kv_context
# Stay block aligned and allocate enough blocks for the new tokens
start_block_idx += cdiv(int(seq_lens[i]), block_size)

View File

@ -0,0 +1,448 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""
import math
from types import MethodType, SimpleNamespace
import numpy as np
import pytest
import torch
from tests.v1.attention.test_mla_backends import (
BATCH_SPECS, BatchSpec, MockAttentionLayer,
create_and_prepopulate_kv_cache)
from tests.v1.attention.utils import (create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl, FlashMLASparseMetadata)
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
for name in [
"mixed_small",
"mixed_medium",
"small_prefill",
"medium_prefill",
"single_prefill",
]
}
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
query_lens=[256] * 2)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens=[256] * 2, query_lens=[256] * 2)
def _dequantize_fp8_ds_mla_entry(
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
# The first kv_lora_rank bytes store FP8 latent values with one scale per
# 128 element tile written as float32 right after the latent payload.
scales = cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
latent = torch.empty(kv_lora_rank,
dtype=torch.float16,
device=cache_slice.device)
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = tile_start + 128
ops.convert_fp8(latent[tile_start:tile_end],
cache_slice[tile_start:tile_end],
float(scales[tile_idx].item()),
kv_dtype="fp8")
latent = latent.to(dtype)
rope_offset = kv_lora_rank // 2 + 8
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
return latent, rope_vals.clone()
def _quantize_dequantize_fp8_ds_mla(
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
if kv_c.numel() == 0:
return kv_c.clone(), k_pe.clone()
kv_lora_rank = kv_c.shape[-1]
rope_dim = k_pe.shape[-1]
num_tokens = kv_c.shape[0]
num_blocks = max(1, math.ceil(num_tokens / block_size))
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
tmp_cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
device=kv_c.device)
slot_mapping = torch.arange(num_tokens,
dtype=torch.long,
device=kv_c.device)
ops.concat_and_cache_mla(kv_c,
k_pe,
tmp_cache,
slot_mapping,
kv_cache_dtype="fp8_ds_mla",
scale=scale)
dequant_kv_c = torch.empty_like(kv_c)
dequant_k_pe = torch.empty_like(k_pe)
for token_idx in range(num_tokens):
slot = slot_mapping[token_idx].item()
block_idx = slot // block_size
block_offset = slot % block_size
cache_slice = tmp_cache[block_idx, block_offset]
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
dequant_kv_c[token_idx] = latent
dequant_k_pe[token_idx] = rope_vals
return dequant_kv_c, dequant_k_pe
def test_sparse_backend_metadata_registration():
backend = FlashMLASparseBackend
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
assert backend.get_metadata_cls() is FlashMLASparseMetadata
assert backend.get_impl_cls() is FlashMLASparseImpl
dtype_list = backend.get_supported_dtypes()
assert torch.bfloat16 in dtype_list
shape = backend.get_kv_cache_shape(num_blocks=2,
block_size=64,
num_kv_heads=1,
head_size=576)
assert shape == (2, 64, 576)
def test_sparse_decode_metadata_filters_prefill_indices():
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
metadata = FlashMLASparseDecodeAndContextMetadata(
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
num_splits=torch.tensor([1, 1], dtype=torch.int32),
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
prefill_context_lengths=prefill_context_lengths,
)
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
context_indices, new_token_indices = metadata.filter_prefill_indices(
indices)
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
dtype=torch.int32)
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
dtype=torch.int32)
assert torch.equal(context_indices, expected_context)
assert torch.equal(new_token_indices, expected_new_tokens)
def test_sparse_impl_zero_fills_when_metadata_missing():
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
dummy_layer = object()
q = torch.zeros((2, 1, 3))
k_c = torch.zeros((2, 3))
k_pe = torch.zeros((2, 1, 1))
kv_cache = torch.zeros((1, 1, 1))
output = torch.ones((2, 4))
result = FlashMLASparseImpl.forward(impl,
dummy_layer,
q,
k_c,
k_pe,
kv_cache,
attn_metadata=None,
output=output)
assert result is output
assert torch.all(result == 0)
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name,
kv_cache_dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
device = torch.device("cuda")
dtype = torch.bfloat16
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
# Model hyper-parameters (kept intentionally small for the unit test)
num_heads = 128
kv_lora_rank = 512
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
head_size = kv_lora_rank + qk_rope_head_dim
topk_tokens = 2048
max_seqlen = max(batch_spec.seq_lens)
total_cache_tokens = sum(batch_spec.seq_lens)
block_size = 64
vllm_config = create_vllm_config(
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
max_model_len=max_seqlen,
num_gpu_blocks=max(2048,
cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size)
model_config = vllm_config.model_config
model_config.hf_config = SimpleNamespace(
attn_module_list_cfg=[{
"topk_tokens": topk_tokens
}])
model_config.hf_text_config = SimpleNamespace(
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
model_type="deepseek_v2",
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: num_heads, model_config)
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
model_config)
model_config.get_head_size = MethodType(lambda self: head_size,
model_config)
model_config.get_sliding_window = MethodType(lambda self: None,
model_config)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
torch.manual_seed(0)
scale = 1.0 / math.sqrt(head_size)
# Shared MLA projection weights to keep reference and backend in sync
W_UK = torch.randn(kv_lora_rank,
num_heads,
qk_nope_head_dim,
dtype=dtype,
device=device)
W_UV = torch.randn(kv_lora_rank,
num_heads,
v_head_dim,
dtype=dtype,
device=device)
# Build synthetic decode-only workload
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
kv_c_contexts, k_pe_contexts = [], []
reference_outputs = []
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
for i in range(batch_spec.batch_size):
s_len = seq_lens[i]
q_len = query_lens[i]
ctx_len = s_len - q_len
q_c = torch.rand(q_len,
num_heads,
qk_nope_head_dim + qk_rope_head_dim,
dtype=dtype,
device=device)
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
k_pe_full = torch.rand(s_len,
1,
qk_rope_head_dim,
dtype=dtype,
device=device)
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
k_pe_full.squeeze(1),
block_size=vllm_config.cache_config.block_size,
scale=kv_cache_scale,
)
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
attn_mask[:, ctx_len:] = causal_mask
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
all_q_vllm.append(q_c)
all_kv_c_vllm.append(kv_c_full[ctx_len:])
all_k_pe_vllm.append(k_pe_full[ctx_len:])
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
sdpa_reference = torch.cat(reference_outputs, dim=0)
vllm_config.cache_config.cache_dtype = kv_cache_dtype
common_attn_metadata = create_common_attn_metadata(
batch_spec,
vllm_config.cache_config.block_size,
device,
arange_block_indices=True)
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=vllm_config.cache_config.block_size,
head_size=head_size,
dtype=dtype,
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=False,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
scale=kv_cache_scale,
)
builder_cls = FlashMLASparseBackend.get_builder_cls()
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
metadata = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
seg_lengths = np.diff(starts)
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths)
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths)
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
topk = metadata.topk_tokens
debug_indices = torch.arange(topk, device=device,
dtype=torch.int32).unsqueeze(0)
token_positions = pos_gpu.unsqueeze(1)
causal_mask = (debug_indices <= token_positions)
debug_indices = torch.where(causal_mask, debug_indices,
torch.full_like(debug_indices, -1))
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
# buffer, so emulate that contract with a simple namespace mock.
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
-1).clone()
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
ok, reason = flashmla.is_flashmla_supported()
if not ok:
pytest.skip(reason)
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
output_size=num_heads *
(qk_nope_head_dim + v_head_dim),
bias=False).to(device=device,
dtype=dtype)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer)
impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device)
out_buffer = torch.empty(metadata.num_actual_tokens,
num_heads * v_head_dim,
dtype=dtype,
device=device)
backend_output = impl.forward(layer,
query_vllm,
kv_c_vllm,
k_pe_vllm,
kv_cache,
metadata,
output=out_buffer)
assert backend_output.shape == sdpa_reference.shape
assert backend_output.dtype == sdpa_reference.dtype
assert torch.isfinite(backend_output).all()
torch.testing.assert_close(backend_output,
sdpa_reference,
rtol=0.5,
atol=0.5)
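The reference path above relies on the usual MLA weight-absorption identity: queries are first projected into the latent space with W_UK so attention can run directly against the cached latents, and per-head values are recovered afterwards with W_UV. A shape-only sketch of the two einsums used above, with small hypothetical sizes:

import torch
q_len, num_heads, nope_dim, latent_dim, v_dim = 4, 8, 16, 32, 16
q_nope = torch.randn(q_len, num_heads, nope_dim)
W_UK = torch.randn(latent_dim, num_heads, nope_dim)
W_UV = torch.randn(latent_dim, num_heads, v_dim)
ql = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)        # queries in latent space
attn_out = torch.randn(q_len, num_heads, latent_dim)   # stand-in for SDPA output
out = torch.einsum("qnl,lnv->qnv", attn_out, W_UV)     # back to per-head values
assert ql.shape == (q_len, num_heads, latent_dim)
assert out.shape == (q_len, num_heads, v_dim)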
@pytest.mark.parametrize(
"seq_lens,max_buf,start,expected",
[
# Basic split: totals per chunk ≤ max_buf
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
# Non-zero start index
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
# Exact fits should split between items when adding the next would
# overflow
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
# All requests fit in a single chunk
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
# Large buffer with non-zero start
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
],
)
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
out = split_prefill_chunks(seq_lens, max_buf, start)
assert out == expected
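The cases above pin down the chunking policy. A minimal greedy reference that reproduces every expected output (a sketch for intuition, not vLLM's implementation; behaviour for a single sequence larger than max_buf is not defined here):

def split_prefill_chunks_ref(seq_lens, max_buf, start):
    # Grow a chunk until adding the next sequence would overflow max_buf,
    # then close it and start a new one.
    chunks, chunk_start, used = [], start, 0
    for i in range(start, len(seq_lens)):
        n = int(seq_lens[i])
        if used and used + n > max_buf:
            chunks.append((chunk_start, i))
            chunk_start, used = i, 0
        used += n
    chunks.append((chunk_start, len(seq_lens)))
    return chunks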

View File

@ -168,7 +168,6 @@ def create_standard_kv_cache_spec(
vllm_config.parallel_config),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
-use_mla=vllm_config.model_config.use_mla,
sliding_window=vllm_config.model_config.get_sliding_window(),
)

View File

@ -24,7 +24,8 @@ from vllm.v1.core.kv_cache_utils import (
make_block_hash_with_group_id)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
-KVCacheTensor, SlidingWindowSpec,
+KVCacheTensor, MLAAttentionSpec,
+SlidingWindowSpec,
UniformTypeKVCacheSpecs)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
-use_mla=False,
sliding_window=None):
return FullAttentionSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
-use_mla=use_mla,
sliding_window=sliding_window)
@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
-use_mla=False,
sliding_window=1):
return SlidingWindowSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
-use_mla=use_mla,
sliding_window=sliding_window)
@ -894,7 +891,6 @@ def test_merge_kv_cache_spec():
num_kv_heads=full_spec.num_kv_heads,
head_size=full_spec.head_size,
dtype=full_spec.dtype,
-use_mla=full_spec.use_mla,
sliding_window=1,
),
]
@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len,
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
-use_mla=False,
)
# Estimate the maximum model length, 16384 model_len need 8GB
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config():
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
-use_mla=False,
)
sliding_window_spec = SlidingWindowSpec(
@ -1030,7 +1024,6 @@
num_kv_heads=32,
head_size=128,
dtype=torch.float16,
-use_mla=False,
sliding_window=1024,
)
@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config():
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
],
)
def new_mla_spec(cache_dtype_str=None):
return MLAAttentionSpec(block_size=16,
num_kv_heads=16,
head_size=64,
dtype=torch.float32,
cache_dtype_str=cache_dtype_str)
def test_merge_mla_spec():
kv_cache_specs = [
new_mla_spec(),
new_mla_spec(),
]
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
assert mla_spec == new_mla_spec()
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
]
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_mla_spec(cache_dtype_str=None),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
kv_cache_specs = [
new_kv_cache_spec(),
new_mla_spec(),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)
kv_cache_specs = [
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
new_kv_cache_spec(),
]
with pytest.raises(AssertionError):
kv_cache_specs[0].merge(kv_cache_specs)

View File

@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
kv_cache_groups=[
KVCacheGroupSpec(
["layer"],
-FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+FullAttentionSpec(block_size, 1, 1, torch.float32),
)
],
)
@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
kv_cache_groups=[
KVCacheGroupSpec(
["layer1"],
-FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+FullAttentionSpec(block_size, 1, 1, torch.float32),
),
KVCacheGroupSpec(
["layer2"],
@ -98,7 +98,6 @@
1,
1,
torch.float32,
-False,
sliding_window=2 * block_size),
),
KVCacheGroupSpec(
@ -107,7 +106,6 @@
1,
1,
torch.float32,
-False,
sliding_window=2 * block_size),
),
],
@ -1338,7 +1336,6 @@ def test_eagle_with_sliding_window():
head_size=1,
dtype=torch.float32,
sliding_window=block_size,
-use_mla=False,
)
manager = KVCacheManager(
KVCacheConfig(

View File

@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
-use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@ -100,7 +99,6 @@ def test_sliding_window_possible_cached_prefix():
head_size=1,
dtype=torch.float32,
sliding_window=4,
-use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@ -165,7 +163,6 @@ def test_chunked_local_attention_remove_skipped_blocks():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,
-use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
@ -217,7 +214,6 @@ def test_sliding_window_remove_skipped_blocks():
head_size=1,
dtype=torch.float32,
sliding_window=4,
-use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
@ -285,7 +281,6 @@ def test_get_num_blocks_to_allocate():
head_size=1,
dtype=torch.float32,
sliding_window=4,  # Placeholder value, not related to test result
-use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
@ -308,7 +303,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
head_size=1,
dtype=torch.float32,
attention_chunk_size=4,  # Placeholder value, not related to test result
-use_mla=False,
)
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)

View File

@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform
+MTP_SIMILARITY_RATE = 0.8
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
@ -222,3 +224,66 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
],
ids=["mimo", "deepseek"])
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_setup: tuple[str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of an original LLM and a speculative LLM;
the outputs should be the same when using MTP speculative decoding.
model_setup: (method, model_name, tp_size)
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_MLA_DISABLE", "1")
method, model_name, tp_size = model_setup
ref_llm = LLM(model=model_name,
max_model_len=2048,
tensor_parallel_size=tp_size,
trust_remote_code=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
tensor_parallel_size=tp_size,
speculative_config={
"method": method,
"num_speculative_tokens": 1,
"max_model_len": 2048,
},
max_model_len=2048,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 80% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@ -836,8 +836,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
mock_spec = FullAttentionSpec(block_size=16, mock_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1, num_kv_heads=1,
head_size=64, head_size=64,
-                                  dtype=torch.float16,
-                                  use_mla=False)
+                                  dtype=torch.float16)
mock_executor.get_kv_cache_specs.return_value = [{ mock_executor.get_kv_cache_specs.return_value = [{
"default": mock_spec "default": mock_spec


@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
time.sleep(self._hand_shake_latency) time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by # These should've been done in register_kv_caches(), called by
# gpu_model_runner. Here we just hardcode some dummy values. # gpu_model_runner. Here we just hardcode some dummy values.
-            self.slot_size_bytes = 4096
-            self.block_len = self.slot_size_bytes * self.block_size
+            slot_size_bytes = 4096
+            self.slot_size_per_layer = [slot_size_bytes]
+            self.block_len_per_layer = [slot_size_bytes * self.block_size]
self.num_blocks = 1 self.num_blocks = 1
self.dst_num_blocks[self.engine_id] = self.num_blocks self.dst_num_blocks[self.engine_id] = self.num_blocks
@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata=FakeNixlWrapper.AGENT_METADATA, agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0], kv_caches_base_addr=[0],
num_blocks=1, num_blocks=1,
block_len=self.block_len, block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
# `self.kv_cache_layout` is only forced to HND when vllm engine # `self.kv_cache_layout` is only forced to HND when vllm engine
# is started. We mock HND here. # is started. We mock HND here.
@ -485,8 +486,8 @@ class TestNixlHandshake:
worker = connector.connector_worker worker = connector.connector_worker
# Minimal local registration params used by add_remote_agent # Minimal local registration params used by add_remote_agent
-        worker.slot_size_bytes = 4096
-        worker.block_len = worker.slot_size_bytes * worker.block_size
+        worker.slot_size_per_layer = [4096]
+        worker.block_len_per_layer = [4096 * worker.block_size]
worker.num_blocks = 1 worker.num_blocks = 1
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
@ -498,7 +499,7 @@ class TestNixlHandshake:
agent_metadata=FakeNixlWrapper.AGENT_METADATA, agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0], kv_caches_base_addr=[0],
num_blocks=1, num_blocks=1,
block_len=worker.block_len, block_lens=worker.block_len_per_layer,
attn_backend_name=worker.backend_name, attn_backend_name=worker.backend_name,
kv_cache_layout=mismatched_layout, kv_cache_layout=mismatched_layout,
) )


@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
"target_attn_1": mock.MagicMock(), "target_attn_1": mock.MagicMock(),
"target_attn_2": mock.MagicMock() "target_attn_2": mock.MagicMock()
} }
target_indx_layers: dict[str, mock.MagicMock] = {}
# Draft model has one extra attention layer compared to target model # Draft model has one extra attention layer compared to target model
all_attn_layers = { all_attn_layers = {
**target_attn_layers, "draft_extra_attn": mock.MagicMock() **target_attn_layers, "draft_extra_attn": mock.MagicMock()
} }
all_indx_layers: dict[str, mock.MagicMock] = {}
# Make mock_get_layers return different values for each call # Make mock_get_layers return different values for each call
-    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+    mock_get_layers.side_effect = [
+        target_attn_layers, target_indx_layers, all_attn_layers,
+        all_indx_layers
+    ]
# Setup mock for pp group to return the appropriate value for world size # Setup mock for pp group to return the appropriate value for world size
mock_pp_group = mock.MagicMock() mock_pp_group = mock.MagicMock()
@ -658,6 +664,9 @@ def test_propose_tree(spec_token_tree):
# Mock runner for attention metadata building. # Mock runner for attention metadata building.
proposer.runner = mock.MagicMock() proposer.runner = mock.MagicMock()
proposer.runner.attn_groups.append([mock.MagicMock()]) proposer.runner.attn_groups.append([mock.MagicMock()])
proposer.runner.attn_groups[0][0].metadata_builders = [
attn_metadata_builder
]
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
attn_metadata_builder attn_metadata_builder
proposer._get_attention_metadata_builder = mock.MagicMock( proposer._get_attention_metadata_builder = mock.MagicMock(


@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest import mock
import pytest
import torch
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
"""Create an MTP proposer with unified model configuration."""
model_config = ModelConfig(model=mimo_7b_dir,
runner="generate",
max_model_len=100,
trust_remote_code=True)
speculative_config = SpeculativeConfig(
target_model_config=model_config,
target_parallel_config=ParallelConfig(),
model=mimo_7b_dir,
method="mtp",
num_speculative_tokens=num_speculative_tokens,
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
speculative_config=speculative_config,
device_config=DeviceConfig(device=current_platform.device_type),
parallel_config=ParallelConfig(),
load_config=LoadConfig(),
scheduler_config=SchedulerConfig())
return EagleProposer(vllm_config=vllm_config,
device=current_platform.device_type)
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
mock_get_pp_group):
"""Test MTP-specific model loading with unified model approach."""
# Setup mocks
mock_model = mock.MagicMock()
mock_model.model.embed_tokens.weight.shape = (131072, 4096)
mock_get_model.return_value = mock_model
target_attn_layers = {"target_attn_1": mock.MagicMock()}
all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
target_indexer_layers: dict = {}
all_indexer_layers: dict = {}
mock_get_layers.side_effect = [
target_attn_layers, target_indexer_layers, all_attn_layers,
all_indexer_layers
]
mock_pp_group = mock.MagicMock()
mock_pp_group.world_size = 1
mock_get_pp_group.return_value = mock_pp_group
# Create target model
class _TargetModelStub(LlamaForCausalLM):
model: mock.MagicMock
lm_head: mock.MagicMock
target_model = mock.create_autospec(_TargetModelStub, instance=True)
target_model.model = mock.MagicMock()
target_model.model.embed_tokens.weight.shape = (131072, 4096)
target_model.lm_head = mock.MagicMock()
# Create MTP proposer
proposer = _create_mtp_proposer(num_speculative_tokens=4)
proposer.load_model(target_model)
# Verify MTP-specific behavior:
# Model is loaded
mock_get_model.assert_called_once()
# MTP shares lm_head with target model
assert proposer.model.lm_head == target_model.lm_head
# MTP shares embed_tokens with target model
assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
"""Test that MTP's forward method returns hidden states directly"""
device = torch.device(current_platform.device_type)
batch_size = 2
seq_lens = [5, 3]
total_tokens = sum(seq_lens)
vocab_size = 100
proposer = _create_mtp_proposer(num_speculative_tokens)
hidden_size = proposer.hidden_size
# Mock the MTP model to verify it returns hidden states directly
model_mock = mock.MagicMock()
# MTP returns hidden states directly
if num_speculative_tokens == 1:
model_mock.return_value = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
# Multiple forward passes for multi-token speculation
forward_returns = []
for i in range(num_speculative_tokens):
if i == 0:
h_states = torch.zeros(total_tokens,
hidden_size,
device=device)
else:
h_states = torch.zeros(batch_size, hidden_size, device=device)
forward_returns.append(h_states)
model_mock.side_effect = forward_returns
# Mock compute_logits
def create_deterministic_logits(batch_size, vocab_size, token_offset):
logits = torch.full((batch_size, vocab_size), -100.0, device=device)
logits[:, token_offset] = 100.0
return logits
if num_speculative_tokens == 1:
model_mock.compute_logits.return_value = create_deterministic_logits(
batch_size, vocab_size, 42)
else:
logits_returns = [
create_deterministic_logits(batch_size, vocab_size, 42 + i)
for i in range(num_speculative_tokens)
]
model_mock.compute_logits.side_effect = logits_returns
proposer.model = model_mock
proposer.attn_layer_names = ["layer.0"]
# Prepare inputs
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
common_attn_metadata = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
target_token_ids = torch.randint(0,
vocab_size, (total_tokens, ),
device=device)
target_positions = torch.cat([
torch.arange(seq_lens[0], device=device),
torch.arange(seq_lens[1], device=device)
])
target_hidden_states = torch.randn(total_tokens,
hidden_size,
device=device)
next_token_ids = torch.randint(0,
vocab_size, (batch_size, ),
dtype=torch.int32,
device=device)
sampling_metadata = mock.MagicMock()
# Setup attention metadata
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
vllm_config=proposer.vllm_config,
device=device,
)
proposer.runner = mock.MagicMock()
proposer.attn_metadata_builder = attn_metadata_builder
# Run propose
result = proposer.propose(target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
# Verify the model was called correctly
assert model_mock.called
# Verify output shape
assert result.shape == (batch_size, num_speculative_tokens)


@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner):
runner.parallel_config), runner.parallel_config),
head_size=runner.model_config.get_head_size(), head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype, dtype=runner.kv_cache_dtype,
use_mla=False,
) )
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig( kv_cache_config = KVCacheConfig(


@ -1678,6 +1678,15 @@ def cp_gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
quant_block_size: int,
kv_cache_dtype: str) -> None:
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
quant_block_size,
kv_cache_dtype)
def get_device_attribute(attribute: int, device: int) -> int: def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


@ -70,6 +70,7 @@ class AttentionBackend(ABC):
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
raise NotImplementedError raise NotImplementedError


@ -95,6 +95,7 @@ class Attention(nn.Module, AttentionLayerBase):
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None, per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False, use_mla: bool = False,
use_sparse: bool = False,
prefix: str = "", prefix: str = "",
attn_type: str = AttentionType.DECODER, attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None, kv_sharing_target_layer_name: Optional[str] = None,
@ -155,6 +156,7 @@ class Attention(nn.Module, AttentionLayerBase):
self._o_scale_float: Optional[float] = None self._o_scale_float: Optional[float] = None
self.use_mla = use_mla self.use_mla = use_mla
self.use_sparse = use_sparse
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
@ -187,7 +189,8 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype, kv_cache_dtype,
block_size, block_size,
use_mla=use_mla, use_mla=use_mla,
has_sink=self.has_sink) has_sink=self.has_sink,
use_sparse=use_sparse)
else: else:
self.attn_backend = attn_backend self.attn_backend = attn_backend


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from typing import List, Optional from typing import ClassVar, List, Optional
import torch import torch
@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches, AttentionCGSupport, CommonAttentionMetadata,
subclass_attention_backend) make_local_attention_virtual_batches, subclass_attention_backend)
from ..layer import Attention from ..layer import Attention
@ -28,6 +28,8 @@ def create_chunked_local_attention_backend(
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,


@ -138,3 +138,208 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx) out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
out = cp_group.reduce_scatter(out, dim=1) out = cp_group.reduce_scatter(out, dim=1)
return out return out
@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
out_ptr, # [B, Lmax, D]
lengths_ptr, # *i32, [B]
N: tl.constexpr,
D: tl.constexpr,
Lmax: tl.constexpr,
PAD_VALUE: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# Compute start index and sequence length from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
# compute input row indices for valid (b, t)
in_row = in_start + off_t
valid_row = (off_t < seq_len) & t_mask
# Pointers
# x_ptr: row-major [N, D]
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
# out_ptr: row-major [B, Lmax, D]
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
None] * D + off_d[None, :]
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
d_mask = off_d[None, :] < D
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
# Load & write only where within seq_len
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
def pack_seq_triton(x: torch.Tensor,
lengths: torch.Tensor,
pad_value: float = -float('inf'),
block_t: int = 64,
block_d: int = 64) -> torch.Tensor:
"""
Pack sequences of different lengths into a batched tensor.
Args:
x: [N, ...] - input tensor where N is total number of tokens
lengths: [B] - sequence lengths for each batch
pad_value: value to use for padding
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
packed: [B, Lmax, ...] - packed tensor
"""
# Handle multi-dimensional input by reshaping to (N, -1)
original_shape = x.shape
if len(original_shape) > 2:
N = original_shape[0]
x_reshaped = x.reshape(N, -1)
D = x_reshaped.shape[1]
else:
N, D = x.shape
x_reshaped = x
B = lengths.numel()
Lmax = int(lengths.max().item())
# Starts are computed inside the kernel from lengths
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_pack_seq_kernel[grid](x_reshaped,
out,
lengths.int(),
N,
D,
Lmax,
PAD_VALUE=float(pad_value),
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 2:
output_shape = (B, Lmax) + original_shape[1:]
out = out.reshape(output_shape)
return out
@triton.jit
def _unpack_seq_triton_kernel(
packed_ptr, # [B, Lmax, D]
out_ptr, # [N, D]
lengths_ptr, # *i32, [B]
B: tl.constexpr,
Lmax: tl.constexpr,
D: tl.constexpr,
BLOCK_T: tl.constexpr, # timesteps per program
BLOCK_D: tl.constexpr # features per program
):
pid_b = tl.program_id(0) # batch id
pid_t = tl.program_id(1) # block over time dimension
pid_d = tl.program_id(2) # block over feature dimension
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
# bounds: compute start from cumulative lengths
in_start = 0
for i in range(pid_b):
in_start += tl.load(lengths_ptr + i)
seq_len = tl.load(lengths_ptr + pid_b)
# valid time positions for this block
t_mask = off_t < Lmax
valid_row = (off_t < seq_len) & t_mask
# compute output row indices for valid (b, t)
out_row = in_start + off_t
# Pointers
# packed_ptr: row-major [B, Lmax, D]
packed_row_ptr = packed_ptr + (pid_b * Lmax +
off_t)[:, None] * D + off_d[None, :]
# out_ptr: row-major [N, D]
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
# Load from packed tensor and store to output
d_mask = off_d[None, :] < D
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
def unpack_seq_triton(packed_tensor: torch.Tensor,
lengths: torch.Tensor,
block_t: int = 64,
block_d: int = 64) -> torch.Tensor:
"""
Unpack a packed decode query tensor back to the original format.
Efficient Triton implementation.
Args:
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
lengths: [B] - sequence lengths for each batch
block_t: block size for time dimension
block_d: block size for feature dimension
Returns:
unpacked_tensor: [N, ...] where N = sum(lengths)
"""
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
original_shape = packed_tensor.shape
if len(original_shape) > 3:
B, Lmax = original_shape[:2]
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
D = packed_reshaped.shape[2]
else:
B, Lmax, D = packed_tensor.shape
packed_reshaped = packed_tensor
# Calculate total number of elements
N = int(lengths.sum().item())
out = torch.empty((N, D),
device=packed_tensor.device,
dtype=packed_tensor.dtype)
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
_unpack_seq_triton_kernel[grid](packed_reshaped,
out,
lengths.int(),
B,
Lmax,
D,
BLOCK_T=block_t,
BLOCK_D=block_d,
num_warps=4,
num_stages=2)
# Reshape output back to original dimensions (except first dimension)
if len(original_shape) > 3:
output_shape = (N, ) + original_shape[2:]
out = out.reshape(output_shape)
return out
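
A small round-trip usage sketch of the two helpers above. It assumes a CUDA device with Triton available and that it runs in (or imports from) the same module that defines pack_seq_triton and unpack_seq_triton; the shapes and dtype are illustrative only.

    import torch

    # Two variable-length sequences (5 and 3 tokens), 64 features per token.
    lengths = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
    x = torch.randn(int(lengths.sum()), 64, device="cuda", dtype=torch.bfloat16)

    packed = pack_seq_triton(x, lengths)           # [2, 5, 64], padded with -inf
    unpacked = unpack_seq_triton(packed, lengths)  # [8, 64]

    # Valid (non-padded) positions survive the round trip unchanged.
    assert torch.equal(unpacked, x)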


@ -19,6 +19,15 @@ if current_platform.is_cuda():
else: else:
_flashmla_C_AVAILABLE = False _flashmla_C_AVAILABLE = False
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> Tuple[bool, Optional[str]]: def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
""" """
@ -37,24 +46,34 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
-   num_heads_per_head_k: int,
+   num_q_tokens_per_head_k: int,
    num_heads_k: int,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+   num_heads_q: Optional[int] = None,
+   is_fp8_kvcache: bool = False,
+   topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
-       cache_seqlens: (batch_size), dtype torch.int32.
-       num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
-       num_heads_k: num_heads_k.
+   - cache_seqlens: (batch_size), dtype torch.int32.
+   - num_q_tokens_per_head_k:
+       Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
+   - num_heads_k: The number of k heads.
+   - num_heads_q:
+       The number of q heads.
+       This argument is optional when sparse attention is not enabled
+   - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
+   - topk: If not None, sparse attention will be enabled,
+       and only tokens in the `indices` array
+       passed to `flash_mla_with_kvcache_sm90` will be attended to.
-   Return:
-       tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
-           dtype torch.int32.
-       num_splits: (batch_size + 1), dtype torch.int32.
+   Returns:
+   - tile_scheduler_metadata:
+       (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
+   - num_splits: (batch_size + 1), dtype torch.int32.
    """
-   return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
-                                                 num_heads_per_head_k,
-                                                 num_heads_k)
+   return torch.ops._flashmla_C.get_mla_decoding_metadata(
+       cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
+       is_fp8_kvcache, topk)
def flash_mla_with_kvcache( def flash_mla_with_kvcache(
@ -69,45 +88,95 @@ def flash_mla_with_kvcache(
causal: bool = False, causal: bool = False,
descale_q: Optional[torch.Tensor] = None, descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Arguments: Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim). - q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32. - cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head_dim of v. - head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - tile_scheduler_metadata:
torch.int32, return by get_mla_metadata. (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. returned by get_mla_metadata.
softmax_scale: float. The scaling of QK^T before applying softmax. - num_splits:
Default to 1 / sqrt(head_dim). (batch_size + 1), torch.int32, returned by get_mla_metadata.
causal: bool. Whether to apply causal attention mask. - softmax_scale: float.
descale_q: (batch_size), torch.float32. Descaling factors for Q. The scale of QK^T before applying softmax.
descale_k: (batch_size), torch.float32. Descaling factors for K. Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Return: Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v). - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
""" """
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( if indices is not None:
q, # NOTE (zyongye): sparse attention is also causal
k_cache, # since it only attend to the tokens before
head_dim_v, # but here `causal` should not be specified
cache_seqlens, assert not causal, \
block_table, "causal must be `false` if sparse attention is enabled."
softmax_scale, assert (descale_q is None) == (
causal, descale_k is None
tile_scheduler_metadata, ), "descale_q and descale_k should be both None or both not None"
num_splits,
descale_q,
descale_k,
)
# Note(hc): need revisit when we support DCP with decode query_len > 1. if indices is None and q.element_size() == 1:
return out.squeeze(1), softmax_lse.squeeze(-1) out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
indices)
return out, softmax_lse
def flash_mla_sparse_prefill(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
sm_scale, d_v)
return results
# #
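
For orientation, a minimal decode-side sketch of how get_mla_metadata and flash_mla_with_kvcache above are typically paired. This is an illustrative sketch, not code from this change set: it assumes a GPU build with the _flashmla_C extension available, and the head counts, head dims (576/512 as in MLA), and block sizes are example values only.

    import torch

    # Illustrative decode step: 4 sequences, 1 query token each, paged KV cache.
    batch_size, seq_len_q, num_heads_q, num_heads_k = 4, 1, 16, 1
    head_dim, head_dim_v, page_block_size, num_blocks = 576, 512, 64, 8

    cache_seqlens = torch.full((batch_size,), 100, dtype=torch.int32, device="cuda")
    q = torch.randn(batch_size, seq_len_q, num_heads_q, head_dim,
                    dtype=torch.bfloat16, device="cuda")
    k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_dim,
                          dtype=torch.bfloat16, device="cuda")
    # Each sequence uses cache blocks 0 and 1 (2 * 64 slots cover 100 tokens).
    block_table = torch.arange(2, dtype=torch.int32,
                               device="cuda").repeat(batch_size, 1)

    # Build scheduling metadata once per batch, then run the paged MLA decode.
    tile_md, num_splits = get_mla_metadata(
        cache_seqlens, seq_len_q * num_heads_q // num_heads_k, num_heads_k)
    out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens,
                                      head_dim_v, tile_md, num_splits,
                                      causal=True)
    # out: [batch_size, seq_len_q, num_heads_q, head_dim_v]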


@ -50,6 +50,7 @@ class PagedAttention:
block_size: int, block_size: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size) return (2, num_blocks, block_size * num_kv_heads * head_size)


@ -144,6 +144,7 @@ def get_attn_backend(
block_size: int, block_size: int,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it.""" """Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@ -158,6 +159,7 @@ def get_attn_backend(
use_v1=envs.VLLM_USE_V1, use_v1=envs.VLLM_USE_V1,
use_mla=use_mla, use_mla=use_mla,
has_sink=has_sink, has_sink=has_sink,
use_sparse=use_sparse,
) )
@ -170,6 +172,7 @@ def _cached_get_attn_backend(
use_v1: bool = False, use_v1: bool = False,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
# Check whether a particular choice of backend was # Check whether a particular choice of backend was
@ -203,7 +206,7 @@ def _cached_get_attn_backend(
# get device-specific attn_backend # get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls( attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla, has_sink) use_mla, has_sink, use_sparse)
if not attention_cls: if not attention_cls:
raise ValueError( raise ValueError(
f"Invalid attention backend for {current_platform.device_name}") f"Invalid attention backend for {current_platform.device_name}")


@ -22,7 +22,8 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128] BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2",
"fp8_inc"]
MambaDType = Literal["auto", "float32"] MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
@ -52,7 +53,11 @@ class CacheConfig:
cache_dtype: CacheDType = "auto" cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type. """Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use
bfloat16 instead, this is an invalid option for models that do not default
to fp8.
"""
is_attention_free: bool = False is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in """Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here.""" `ModelConfig` and that value should be manually duplicated here."""
@ -171,11 +176,12 @@ class CacheConfig:
if self.cache_dtype == "auto": if self.cache_dtype == "auto":
pass pass
elif self.cache_dtype in get_args(CacheDType): elif self.cache_dtype in get_args(CacheDType):
-           logger.info(
-               "Using fp8 data type to store kv cache. It reduces the GPU "
-               "memory footprint and boosts the performance. "
-               "Meanwhile, it may cause accuracy drop without a proper "
-               "scaling factor.")
+           if self.cache_dtype.startswith("fp8"):
+               logger.info(
+                   "Using fp8 data type to store kv cache. It reduces the GPU "
+                   "memory footprint and boosts the performance. "
+                   "Meanwhile, it may cause accuracy drop without a proper "
+                   "scaling factor.")
else: else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
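
For illustration, the new "bfloat16" option can be exercised through the usual kv_cache_dtype engine argument, which feeds CacheConfig.cache_dtype; the model name below is only an example of a model whose KV cache defaults to fp8, and this sketch is not part of the change set.

    from vllm import LLM

    # Keep the KV cache in bfloat16 instead of the model's fp8 default.
    llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp",
              kv_cache_dtype="bfloat16",
              trust_remote_code=True)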


@ -360,6 +360,7 @@ class CompilationConfig:
"vllm.linear_attention", "vllm.linear_attention",
"vllm.plamo2_mamba_mixer", "vllm.plamo2_mamba_mixer",
"vllm.gdn_attention", "vllm.gdn_attention",
"vllm.sparse_attn_indexer",
] ]
def compute_hash(self) -> str: def compute_hash(self) -> str:


@ -137,6 +137,9 @@ class ModelConfig:
"""Allowing API requests to read local images or videos from directories """Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only specified by the server file system. This is a security risk. Should only
be enabled in trusted environments.""" be enabled in trusted environments."""
allowed_media_domains: Optional[list[str]] = None
"""If set, only media URLs that belong to this domain can be used for
multi-modal inputs. """
revision: Optional[str] = None revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name, """The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version.""" or a commit id. If unspecified, will use the default version."""
@ -1074,14 +1077,14 @@ class ModelConfig:
if not hasattr(self.hf_text_config, "model_type"): if not hasattr(self.hf_text_config, "model_type"):
return False return False
elif self.hf_text_config.model_type in \ elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', ('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp',
'kimi_k2', 'longcat_flash'): 'kimi_k2', 'longcat_flash'):
return self.hf_text_config.kv_lora_rank is not None return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle': elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the # if the model is an EAGLE module, check for the
# underlying architecture # underlying architecture
return self.hf_text_config.model.model_type in \ return self.hf_text_config.model.model_type in \
('deepseek_v2', 'deepseek_v3') \ ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
and self.hf_text_config.kv_lora_rank is not None and self.hf_text_config.kv_lora_rank is not None
return False return False


@ -279,6 +279,24 @@ class ParallelConfig:
assert last_exc is not None assert last_exc is not None
raise last_exc raise last_exc
# The all_reduce at the end of attention (during o_proj) means that
# inputs are replicated across each rank of the tensor parallel group.
# If using expert-parallelism with DeepEP All2All ops, replicated
# tokens results in useless duplicate computation and communication.
#
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
# Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (envs.VLLM_ALL2ALL_BACKEND
in ("allgather_reducescatter", "naive",
"deepep_high_throughput", "deepep_low_latency")
and self.enable_expert_parallel
and self.tensor_parallel_size > 1
and self.data_parallel_size > 1)
@staticmethod @staticmethod
def has_unfinished_dp(dp_group: ProcessGroup, def has_unfinished_dp(dp_group: ProcessGroup,
has_unfinished: bool) -> bool: has_unfinished: bool) -> bool:
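
Restated as a standalone predicate for readability (this helper is illustrative, not part of the change set), the property above reduces to:

    def _should_use_sequence_parallel_moe(all2all_backend: str,
                                          enable_expert_parallel: bool,
                                          tensor_parallel_size: int,
                                          data_parallel_size: int) -> bool:
        # Mirrors ParallelConfig.use_sequence_parallel_moe: sequence-parallel MoE
        # input only pays off when expert parallelism is enabled, TP > 1 would
        # otherwise replicate tokens, and DP > 1 means the selected All2All
        # backend would ship those duplicates around.
        return (all2all_backend in ("allgather_reducescatter", "naive",
                                    "deepep_high_throughput",
                                    "deepep_low_latency")
                and enable_expert_parallel
                and tensor_parallel_size > 1
                and data_parallel_size > 1)

    assert _should_use_sequence_parallel_moe("deepep_low_latency", True, 2, 2)
    assert not _should_use_sequence_parallel_moe("pplx", True, 2, 2)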


@ -32,14 +32,17 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp", "mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp", "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"] "longcat_flash_mtp", "mtp"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
"qwen3_next_mtp", "longcat_flash_mtp")
@config @config
@dataclass @dataclass
class SpeculativeConfig: class SpeculativeConfig:
"""Configuration for speculative decoding.""" """Configuration for speculative decoding."""
enforce_eager: Optional[bool] = None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control # General speculative decoding control
num_speculative_tokens: SkipValidation[int] = None # type: ignore num_speculative_tokens: SkipValidation[int] = None # type: ignore
"""The number of speculative tokens, if provided. It will default to the """The number of speculative tokens, if provided. It will default to the
@ -143,7 +146,7 @@ class SpeculativeConfig:
@staticmethod @staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3": if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp" hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp": if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None) n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
@ -207,11 +210,21 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by # can not be detected, it will be considered as the "draft_model" by
# default. # default.
if self.method in MTP_MODEL_TYPES:
logger.warning("method `%s` is deprecated and replaced with mtp.",
self.method)
self.method = "mtp"
if self.model is None and self.num_speculative_tokens is not None: if self.model is None and self.num_speculative_tokens is not None:
-           # TODO(Shangming): Refactor mtp configuration logic when supporting
-           if (self.target_model_config
-                   and self.target_model_config.hf_text_config.model_type
-                   in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
+           if self.method == "mtp":
+               assert (
+                   self.target_model_config
+                   is not None), "target_model_config must be present for mtp"
+               if self.target_model_config.hf_text_config.model_type \
+                       == "deepseek_v32":
+                   # FIXME(luccafong): cudgraph with v32 MTP is not supported,
+                   # remove this when the issue is fixed.
+                   self.enforce_eager = True
# use the draft model from the same model: # use the draft model from the same model:
self.model = self.target_model_config.model self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as # Align the quantization of draft model for cases such as
@ -281,6 +294,8 @@ class SpeculativeConfig:
trust_remote_code, trust_remote_code,
allowed_local_media_path=self.target_model_config. allowed_local_media_path=self.target_model_config.
allowed_local_media_path, allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype, dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed, seed=self.target_model_config.seed,
revision=self.revision, revision=self.revision,
@ -312,31 +327,13 @@ class SpeculativeConfig:
"mlp_speculator"): "mlp_speculator"):
self.method = "mlp_speculator" self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): in MTP_MODEL_TYPES):
self.method = "deepseek_mtp" self.method = "mtp"
if self.num_speculative_tokens > 1: if self.num_speculative_tokens > 1:
logger.warning( logger.warning(
"All Deepseek MTP models only have " \ "Enabling num_speculative_tokens > 1 will run" \
"one layer. Might need some code changes " \ "multiple times of forward on same MTP layer" \
"to support multiple layers." ",which may result in lower acceptance rate" \
)
elif (self.draft_model_config.hf_config.model_type ==
"ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
) )
elif (self.draft_model_config.hf_config.model_type elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")): in ("longcat_flash_mtp")):
@ -353,7 +350,7 @@ class SpeculativeConfig:
"Speculative decoding with draft model is not " "Speculative decoding with draft model is not "
"supported yet. Please consider using other " "supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, " "speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.") "eagle, or mtp.")
# Replace hf_config for EAGLE draft_model # Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"): if self.method in ("eagle", "eagle3"):
@ -562,8 +559,7 @@ class SpeculativeConfig:
return self.num_speculative_tokens return self.num_speculative_tokens
def use_eagle(self) -> bool: def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp", return self.method in ("eagle", "eagle3", "mtp")
"qwen3_next_mtp", "longcat_flash_mtp")
def __repr__(self) -> str: def __repr__(self) -> str:
method = self.method method = self.method
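
As a usage reference, the consolidated "mtp" method name is passed to the engine exactly as in the test added earlier in this change set (model name and token counts are taken from that test):

    from vllm import LLM

    spec_llm = LLM(model="XiaomiMiMo/MiMo-7B-Base",
                   trust_remote_code=True,
                   max_model_len=2048,
                   speculative_config={
                       "method": "mtp",
                       "num_speculative_tokens": 1,
                       "max_model_len": 2048,
                   })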


@ -54,6 +54,7 @@ class HTTPConnection:
stream: bool = False, stream: bool = False,
timeout: Optional[float] = None, timeout: Optional[float] = None,
extra_headers: Optional[Mapping[str, str]] = None, extra_headers: Optional[Mapping[str, str]] = None,
allow_redirects: bool = True,
): ):
self._validate_http_url(url) self._validate_http_url(url)
@ -63,7 +64,8 @@ class HTTPConnection:
return client.get(url, return client.get(url,
headers=self._headers(**extra_headers), headers=self._headers(**extra_headers),
stream=stream, stream=stream,
timeout=timeout) timeout=timeout,
allow_redirects=allow_redirects)
async def get_async_response( async def get_async_response(
self, self,
@ -71,6 +73,7 @@ class HTTPConnection:
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
extra_headers: Optional[Mapping[str, str]] = None, extra_headers: Optional[Mapping[str, str]] = None,
allow_redirects: bool = True,
): ):
self._validate_http_url(url) self._validate_http_url(url)
@ -79,10 +82,17 @@ class HTTPConnection:
            return client.get(url,
                              headers=self._headers(**extra_headers),
-                             timeout=timeout)
+                             timeout=timeout,
+                             allow_redirects=allow_redirects)

-   def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
-       with self.get_response(url, timeout=timeout) as r:
+   def get_bytes(self,
+                 url: str,
+                 *,
+                 timeout: Optional[float] = None,
+                 allow_redirects: bool = True) -> bytes:
+       with self.get_response(url,
+                              timeout=timeout,
+                              allow_redirects=allow_redirects) as r:
            r.raise_for_status()
            return r.content

@@ -92,8 +102,10 @@ class HTTPConnection:
        url: str,
        *,
        timeout: Optional[float] = None,
+       allow_redirects: bool = True,
    ) -> bytes:
-       async with await self.get_async_response(url, timeout=timeout) as r:
+       async with await self.get_async_response(
+               url, timeout=timeout, allow_redirects=allow_redirects) as r:
            r.raise_for_status()
            return await r.read()
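
A hedged usage sketch of the new flag: global_http_connection is assumed to be the module-level HTTPConnection instance, and the URL is illustrative. With allow_redirects=False, the domain that was validated against allowed_media_domains is the domain that actually serves the bytes, so a redirect cannot bypass the check.

    from vllm.connections import global_http_connection  # assumed module-level instance

    image_bytes = global_http_connection.get_bytes(
        "https://allowed.example.com/image.png",
        timeout=5,
        allow_redirects=False,
    )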


@ -6,7 +6,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import get_dp_group from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx from vllm.utils import has_deep_ep, has_pplx
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
super().__init__(cpu_group) super().__init__(cpu_group)
def naive_multicast(self, x: torch.Tensor, def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor): cu_tokens_across_sp_cpu: torch.Tensor,
is_sequence_parallel: bool) -> torch.Tensor:
assert (len(x.shape) == 2) assert (len(x.shape) == 2)
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
device=x.device, device=x.device,
dtype=x.dtype) dtype=x.dtype)
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ rank = self.rank if is_sequence_parallel else self.dp_rank
self.dp_rank - 1] world_size = (self.world_size
end = cu_tokens_across_dp_cpu[self.dp_rank] if is_sequence_parallel else self.dp_world_size)
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
end = cu_tokens_across_sp_cpu[rank]
buffer[start:end, :].copy_(x) buffer[start:end, :].copy_(x)
for idx in range(self.dp_world_size): for idx in range(world_size):
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
end = cu_tokens_across_dp_cpu[idx] end = cu_tokens_across_sp_cpu[idx]
self.dp_group.broadcast(buffer[start:end, :], idx) get_ep_group().broadcast(buffer[start:end, :], idx)
return buffer return buffer
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
sizes = get_forward_context( hidden_states: torch.Tensor,
).dp_metadata.get_chunk_sizes_across_dp_rank() router_logits: torch.Tensor,
hidden_states, router_logits = get_dp_group().all_gatherv( is_sequence_parallel: bool = False
[hidden_states, router_logits], ) -> tuple[torch.Tensor, torch.Tensor]:
dim=0, sp_size = self.tp_group.world_size if is_sequence_parallel else 1
sizes=sizes, dp_metadata = get_forward_context().dp_metadata
) cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_sp_cpu,
is_sequence_parallel)
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
sizes = get_forward_context( hidden_states: torch.Tensor,
).dp_metadata.get_chunk_sizes_across_dp_rank() is_sequence_parallel: bool = False) -> torch.Tensor:
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0, ep_rank = self.rank if is_sequence_parallel else self.dp_rank
sizes=sizes)
dp_metadata = get_forward_context().dp_metadata
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
end = cu_tokens_across_sp_cpu[ep_rank]
all_hidden_states = get_ep_group().all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
return hidden_states return hidden_states
def destroy(self): def destroy(self):
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
def __init__(self, cpu_group): def __init__(self, cpu_group):
super().__init__(cpu_group) super().__init__(cpu_group)
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Gather hidden_states and router_logits from all dp ranks. Gather hidden_states and router_logits from all dp ranks.
""" """
sizes = get_forward_context( sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank() ).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits], [hidden_states, router_logits],
dim=0, dim=0,
sizes=sizes, sizes=sizes,
) )
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
""" """
Reduce-scatter hidden_states across all dp ranks. Reduce-scatter hidden_states across all dp ranks.
""" """
sizes = get_forward_context( sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank() ).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0, dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
sizes=sizes) hidden_states = dist_group.reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states return hidden_states
def destroy(self): def destroy(self):
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
kwargs, pplx.AllToAll.internode kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode) if self.internode else pplx.AllToAll.intranode)
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
def dispatch(self, hidden_states: torch.Tensor, def dispatch(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError raise NotImplementedError
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
self.workspace_tensor = None self.workspace_tensor = None
self.prepare_workspace_tensor = None self.prepare_workspace_tensor = None
self.mapping = None self.mapping = None
self.initialized = False self.initialized = False


@ -28,6 +28,8 @@ class Cache:
class All2AllManagerBase: class All2AllManagerBase:
rank: int
world_size: int
def __init__(self, cpu_group): def __init__(self, cpu_group):
self.cpu_group = cpu_group self.cpu_group = cpu_group
@ -40,6 +42,7 @@ class All2AllManagerBase:
# all2all lives in ep group, which is merged from dp and tp group # all2all lives in ep group, which is merged from dp and tp group
self.dp_group = get_dp_group() self.dp_group = get_dp_group()
self.tp_group = get_tp_group() self.tp_group = get_tp_group()
# no self.ep_group since self.ep_group is still in construction # no self.ep_group since self.ep_group is still in construction
# when we create this object # when we create this object
self.dp_rank = self.dp_group.rank_in_group self.dp_rank = self.dp_group.rank_in_group
@ -60,17 +63,21 @@ class All2AllManagerBase:
# and reuse it for the same config. # and reuse it for the same config.
raise NotImplementedError raise NotImplementedError
def dispatch(self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False):
raise NotImplementedError
def set_num_sms(self, num_sms: int): def set_num_sms(self, num_sms: int):
pass pass
def max_sms_used(self) -> Optional[int]: def max_sms_used(self) -> Optional[int]:
return None # None means it could use the whole GPU return None # None means it could use the whole GPU
def dispatch(self, hidden_states: torch.Tensor, def combine(self,
router_logits: torch.Tensor): hidden_states: torch.Tensor,
raise NotImplementedError is_sequence_parallel: bool = False):
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def destroy(self): def destroy(self):
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
module.quant_method.init_prepare_finalize(module) module.quant_method.init_prepare_finalize(module)
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Dispatch the hidden states and router logits to the appropriate device. Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.
""" """
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
""" """
Combine the hidden states and router logits from the appropriate device. Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class. This is a no-op in the base class.


@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
# ep does not use pynccl
use_pynccl = "ep" not in unique_name
self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem self.use_torch_symm_mem = use_torch_symm_mem
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
SymmMemCommunicator) SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1: if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator( self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group, group=self.cpu_group,
device=self.device, device=self.device,
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list return output_list
def dispatch( def dispatch(
self, hidden_states: torch.Tensor, self,
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch( hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits) hidden_states, router_logits, is_sequence_parallel)
return hidden_states, router_logits return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: def combine(self,
hidden_states: torch.Tensor,
is_sequence_parallel: bool = False) -> torch.Tensor:
assert self.all2all_manager is not None assert self.all2all_manager is not None
hidden_states = self.all2all_manager.combine(hidden_states) hidden_states = self.all2all_manager.combine(hidden_states,
is_sequence_parallel)
return hidden_states return hidden_states


@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
         dist.broadcast(input_, src=src, group=self.device_group)

     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.all2all_manager is not None
         hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits)
+            hidden_states, router_logits, is_sequence_parallel)
         return hidden_states, router_logits

-    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def combine(self,
+                hidden_states: torch.Tensor,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         assert self.all2all_manager is not None
-        hidden_states = self.all2all_manager.combine(hidden_states)
+        hidden_states = self.all2all_manager.combine(hidden_states,
+                                                     is_sequence_parallel)
         return hidden_states

View File

@ -84,7 +84,7 @@ class NixlAgentMetadata(
agent_metadata: bytes agent_metadata: bytes
kv_caches_base_addr: list[int] kv_caches_base_addr: list[int]
num_blocks: int num_blocks: int
-    block_len: int
+    block_lens: list[int]
attn_backend_name: str attn_backend_name: str
kv_cache_layout: str kv_cache_layout: str
@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
self.reqs_to_save: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {} self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_in_batch: set[ReqId] = set()
def add_new_req( def add_new_req(
self, self,
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time # Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {} self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, request: "Request", self, request: "Request",
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
if not params: if not params:
return return
if params.get("do_remote_decode"):
self._reqs_in_batch.add(request.request_id)
if self.use_host_buffer and params.get("do_remote_decode"): if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl, # NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer. # prefilled blocks need to be saved to host memory before transfer.
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
request_id=req_id, request_id=req_id,
local_block_ids=block_ids, local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params,
load_remote_cache=True,
save_to_host=False,
) )
for req_id, (req, block_ids) in self._reqs_need_save.items(): for req_id, (req, block_ids) in self._reqs_need_save.items():
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
) )
meta.reqs_to_send = self._reqs_need_send meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
# Clear the list once workers start the transfers # Clear the list once workers start the transfers
self._reqs_need_recv.clear() self._reqs_need_recv.clear()
self._reqs_need_save.clear() self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_need_send = {} self._reqs_need_send = {}
return meta return meta
@ -465,8 +474,11 @@ class NixlConnectorWorker:
"backends", ["UCX"]) "backends", ["UCX"])
# Agent. # Agent.
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"] non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
-        config = nixl_agent_config(backends=self.nixl_backends) if len(
-            non_ucx_backends) > 0 and nixl_agent_config is not None else None
+        if nixl_agent_config is None:
+            config = None
+        else:
+            config = nixl_agent_config(backends=self.nixl_backends) if len(
+                non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config) self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@ -546,6 +558,8 @@ class NixlConnectorWorker:
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
# Track the expiration time of requests that are waiting to be sent. # Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {} self._reqs_to_send: dict[ReqId, float] = {}
# Set of requests that have been part of a batch, regardless of status.
self._reqs_to_process: set[ReqId] = set()
# Background thread for handling new handshake requests. # Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@ -752,6 +766,9 @@ class NixlConnectorWorker:
split_k_and_v = not (self.use_mla or self._use_pallas split_k_and_v = not (self.use_mla or self._use_pallas
or self._use_flashinfer) or self._use_flashinfer)
tensor_size_bytes = None tensor_size_bytes = None
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items(): for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [ cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches cache_or_caches
@ -769,10 +786,25 @@ class NixlConnectorWorker:
                 tensor_size_bytes = curr_tensor_size_bytes
                 self.num_blocks = cache.shape[0]

-            assert tensor_size_bytes == curr_tensor_size_bytes, \
-                "All kv cache tensors must have the same size"
+            assert cache.shape[0] == self.num_blocks, \
+                "All kv cache tensors must have the same number of blocks"
+
+            self.block_len_per_layer.append(curr_tensor_size_bytes //
+                                            self.num_blocks)
+            self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
+                                            self.block_size)
+
+            if not self.use_mla:
+                # Different kv cache shape is not supported by HeteroTP
+                assert tensor_size_bytes == curr_tensor_size_bytes, \
+                    "All kv cache tensors must have the same size"
+
             caches_data.append(
-                (base_addr, tensor_size_bytes, self.tp_rank, ""))
+                (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
+
+        logger.debug("Different block lengths collected: %s",
+                     set(self.block_len_per_layer))
+        assert len(self.block_len_per_layer) == len(seen_base_addresses)
+        assert self.num_blocks != 0
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data) self.num_regions = len(caches_data)
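
For context, a standalone sketch of the per-layer bookkeeping introduced here, using dummy tensors in place of real KV caches. The `block_size` and cache shapes are made-up values; only the bytes-per-block and bytes-per-slot arithmetic mirrors the change:

```python
import torch

block_size = 16  # tokens per block (assumed)
# One fake "cache" per layer; MLA-style models may use different widths.
kv_caches = {
    "layer_0": torch.zeros(32, block_size, 576, dtype=torch.uint8),
    "layer_1": torch.zeros(32, block_size, 512, dtype=torch.uint8),
}

block_len_per_layer: list[int] = []
slot_size_per_layer: list[int] = []
num_blocks = None

for name, cache in kv_caches.items():
    tensor_size_bytes = cache.numel() * cache.element_size()
    if num_blocks is None:
        num_blocks = cache.shape[0]
    assert cache.shape[0] == num_blocks, \
        "All kv cache tensors must have the same number of blocks"
    # Bytes occupied by a single block of this layer's cache.
    block_len_per_layer.append(tensor_size_bytes // num_blocks)
    # Bytes occupied by a single token slot within a block.
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)

print(block_len_per_layer)  # [9216, 8192] with the shapes above
print(slot_size_per_layer)  # [576, 512]
```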
@ -785,16 +817,12 @@ class NixlConnectorWorker:
logger.debug("Done registering descs") logger.debug("Done registering descs")
self._registered_descs.append(descs) self._registered_descs.append(descs)
-        assert tensor_size_bytes is not None
-        assert self.num_blocks != 0
-        assert tensor_size_bytes % self.num_blocks == 0
-        self.block_len = tensor_size_bytes // self.num_blocks
-        self.slot_size_bytes = self.block_len // self.block_size
         self.device_kv_caches = kv_caches
         self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
-            assert self.slot_size_bytes % 2 == 0
-            self.slot_size_bytes /= 2
+            for i in range(len(self.slot_size_per_layer)):
+                assert self.slot_size_per_layer[i] % 2 == 0
+                self.slot_size_per_layer[i] //= 2
# NOTE (NickLucche) When FlashInfer is used, memory is registered # NOTE (NickLucche) When FlashInfer is used, memory is registered
# with joint KV for each block. This minimizes the overhead in # with joint KV for each block. This minimizes the overhead in
@ -804,17 +832,17 @@ class NixlConnectorWorker:
# of 'virtual' regions here and halve `block_len` below. # of 'virtual' regions here and halve `block_len` below.
self.num_regions *= 2 self.num_regions *= 2
-        kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
-        for base_addr in seen_base_addresses:
+        for i, base_addr in enumerate(seen_base_addresses):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
             # NOTE With heter-TP, more blocks are prepared than what are
             # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
             # could create fewer, but then _get_block_descs_ids needs to
             # select agent_meta.num_blocks instead of self.num_blocks for
             # local descr, and that makes handling regular flow less clean.
             for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
+                block_offset = block_id * self.block_len_per_layer[i]
                 addr = base_addr + block_offset
                 # (addr, len, device id)
                 blocks_data.append((addr, kv_block_len, self.tp_rank))
@ -824,7 +852,7 @@ class NixlConnectorWorker:
                 # descs ordering. This is needed for selecting contiguous heads
                 # when split across TP ranks.
                 for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
+                    block_offset = block_id * self.block_len_per_layer[i]
                     addr = base_addr + block_offset
                     # Register addresses for V cache (K registered first).
                     v_addr = addr + kv_block_len
@ -864,7 +892,7 @@ class NixlConnectorWorker:
agent_metadata=self.nixl_wrapper.get_agent_metadata(), agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks, num_blocks=self.num_blocks,
-            block_len=self.block_len,
+            block_lens=self.block_len_per_layer,
attn_backend_name=self.backend_name, attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout) kv_cache_layout=self.kv_cache_layout)
ready_event = threading.Event() ready_event = threading.Event()
@ -889,7 +917,7 @@ class NixlConnectorWorker:
The latter, assuming D.world_size > P.world_size, requires that two or The latter, assuming D.world_size > P.world_size, requires that two or
more local TP worker share the xfer from a single TP worker. more local TP worker share the xfer from a single TP worker.
-        Here's an example:
+        Here's an example (non-MLA case):
rank_offset p_remote_tp_rank rank_offset p_remote_tp_rank
(kv split no) (kv split no)
@ -945,14 +973,20 @@ class NixlConnectorWorker:
total_num_kv_heads = self.model_config.get_total_num_kv_heads() total_num_kv_heads = self.model_config.get_total_num_kv_heads()
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1 is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
+        remote_block_len = nixl_agent_meta.block_lens[0]
         if self.use_mla or is_kv_replicated:
-            # With MLA the only difference is in the number of blocks.
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes)
-            assert self.block_len == nixl_agent_meta.block_len
+            # With replicated KV cache, only the number of blocks can differ.
+            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
+                "KV cache sizes must match between P and D when replicated"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0])
         else:
-            remote_block_size = nixl_agent_meta.block_len // (
-                self.slot_size_bytes * tp_ratio)
+            # When MLA is not used, this is a list of the same block length
+            for block_len in nixl_agent_meta.block_lens:
+                assert block_len == remote_block_len, \
+                    "All remote layers must have the same block size"
+            remote_block_size = remote_block_len // (
+                self.slot_size_per_layer[0] * tp_ratio)
if self._use_flashinfer: if self._use_flashinfer:
# With flashinfer, KV are sent in the same message. # With flashinfer, KV are sent in the same message.
remote_block_size //= 2 remote_block_size //= 2
@ -963,14 +997,14 @@ class NixlConnectorWorker:
raise ValueError( raise ValueError(
"Heterogeneous TP is not supported on XPU") "Heterogeneous TP is not supported on XPU")
-            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
+            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
                 "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
             )

         assert self.block_size == remote_block_size, (
-            "Remote P worker with different block size is not supported "
-            f"{self.block_size=} {remote_block_size=}")
+            "Remote P worker with different page/block size is not supported "
+            f"{self.block_size=}, {remote_block_size=}")
# Create dst descs and xfer side handles. TP workers have same #blocks. # Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks: if engine_id in self.dst_num_blocks:
@ -985,13 +1019,16 @@ class NixlConnectorWorker:
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
self.kv_caches_base_addr[ self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr engine_id] = nixl_agent_meta.kv_caches_base_addr
-        kv_block_len = self.get_backend_aware_kv_block_len()
-        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
-            if not (self.use_mla or is_kv_replicated) else 0
+        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
+            self.block_len_per_layer)
         # Register all remote blocks, but only the corresponding kv heads.
-        for base_addr in nixl_agent_meta.kv_caches_base_addr:
+        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
+            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
+            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
+                if not (self.use_mla or is_kv_replicated) else 0
             for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * nixl_agent_meta.block_len
+                block_offset = block_id * nixl_agent_meta.block_lens[i]
                 # For each block, grab the heads chunk belonging to rank_i
                 # of size remote_nheads // tp_ratio, which correspond to
                 # self.block_len == remote_block_len//tp_ratio bytes.
@ -1002,9 +1039,9 @@ class NixlConnectorWorker:
             if self._use_flashinfer:
                 # With FlashInfer index V separately to allow head splitting.
                 for block_id in range(nixl_agent_meta.num_blocks):
-                    block_offset = block_id * nixl_agent_meta.block_len
+                    block_offset = block_id * nixl_agent_meta.block_lens[i]
                     addr = base_addr + block_offset + rank_offset
-                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
                     blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
logger.debug( logger.debug(
@ -1082,6 +1119,7 @@ class NixlConnectorWorker:
"Releasing expired KV blocks for request %s which were " "Releasing expired KV blocks for request %s which were "
"retrieved by %d decode worker(s) within %d seconds.", req_id, "retrieved by %d decode worker(s) within %d seconds.", req_id,
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT) count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
self._reqs_to_process.remove(req_id)
del self._reqs_to_send[req_id] del self._reqs_to_send[req_id]
done_sending.add(req_id) done_sending.add(req_id)
@ -1097,7 +1135,8 @@ class NixlConnectorWorker:
for notifs in self.nixl_wrapper.get_new_notifs().values(): for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs: for notif in notifs:
                 req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if req_id not in self._reqs_to_send:
+                if (req_id not in self._reqs_to_send
+                        and req_id not in self._reqs_to_process):
                     logger.error(
                         "Potentially invalid KV blocks for "
                         "unrecognized request %s were retrieved by "
@ -1110,7 +1149,8 @@ class NixlConnectorWorker:
                     tp_ratio):
                 notified_req_ids.add(req_id)
                 del self.consumer_notification_counts_by_req[req_id]
-                del self._reqs_to_send[req_id]
+                self._reqs_to_process.remove(req_id)
+                self._reqs_to_send.pop(req_id, None)
return notified_req_ids return notified_req_ids
def _pop_done_transfers( def _pop_done_transfers(
@ -1171,8 +1211,19 @@ class NixlConnectorWorker:
while not self._ready_requests.empty(): while not self._ready_requests.empty():
self._read_blocks_for_req(*self._ready_requests.get_nowait()) self._read_blocks_for_req(*self._ready_requests.get_nowait())
+        # Keep around the requests that have been part of a batch. This is
+        # needed because async scheduling can widen the gap between the moment
+        # a request's expiration is set (on the P side) and the moment its
+        # blocks are read from D. As P can now more easily lag behind D while
+        # processing the next batch, only set an expiration for requests whose
+        # blocks have not been read from D yet.
+        for req_id in metadata.reqs_in_batch:
+            self._reqs_to_process.add(req_id)
+
         # Add to requests that are waiting to be read and track expiration.
-        self._reqs_to_send.update(metadata.reqs_to_send)
+        for req_id, expiration_time in metadata.reqs_to_send.items():
+            if req_id in self._reqs_to_process:
+                self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
logger.debug( logger.debug(
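
A hedged sketch of this two-stage bookkeeping with plain dicts/sets and invented request ids, showing why an expiration is only recorded for requests that were actually part of a batch on this worker:

```python
import time

reqs_to_process: set[str] = set()     # requests seen in some batch
reqs_to_send: dict[str, float] = {}   # request id -> expiration timestamp

# Scheduler-side metadata for one step (hypothetical values).
reqs_in_batch = {"req-a"}
reqs_to_send_meta = {"req-a": time.monotonic() + 120,
                     "req-b": time.monotonic() + 120}

for req_id in reqs_in_batch:
    reqs_to_process.add(req_id)

# "req-b" was never part of a batch on this worker (e.g. its blocks were
# already read earlier under async scheduling), so no expiration is tracked.
for req_id, expiration in reqs_to_send_meta.items():
    if req_id in reqs_to_process:
        reqs_to_send[req_id] = expiration

assert "req-b" not in reqs_to_send
```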
@ -1317,7 +1368,7 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten() return descs_ids.flatten()
-    def get_backend_aware_kv_block_len(self):
+    def get_backend_aware_kv_block_len(self, layer_idx: int):
""" """
Get the block length for one K/V element (K and V have the same size). Get the block length for one K/V element (K and V have the same size).
@ -1328,9 +1379,9 @@ class NixlConnectorWorker:
""" """
         if self._use_flashinfer:
             # For indexing only half (either just the K or V part).
-            block_len = self.block_len // 2
+            block_len = self.block_len_per_layer[layer_idx] // 2
         else:
-            block_len = self.block_len
+            block_len = self.block_len_per_layer[layer_idx]
return block_len return block_len
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]: def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:

View File

@ -871,17 +871,24 @@ class GroupCoordinator:
model) model)
     def dispatch(
-            self, hidden_states: torch.Tensor,
-            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        is_sequence_parallel: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         if self.device_communicator is not None:
             return self.device_communicator.dispatch(hidden_states,
-                                                     router_logits)
+                                                     router_logits,
+                                                     is_sequence_parallel)
         else:
             return hidden_states, router_logits

-    def combine(self, hidden_states) -> torch.Tensor:
+    def combine(self,
+                hidden_states,
+                is_sequence_parallel: bool = False) -> torch.Tensor:
         if self.device_communicator is not None:
-            return self.device_communicator.combine(hidden_states)
+            return self.device_communicator.combine(hidden_states,
+                                                    is_sequence_parallel)
         else:
             return hidden_states

View File

@ -297,6 +297,8 @@ class EngineArgs:
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: Optional[
list[str]] = ModelConfig.allowed_media_domains
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
@ -531,6 +533,8 @@ class EngineArgs:
**model_kwargs["hf_config_path"]) **model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path", model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"]) **model_kwargs["allowed_local_media_path"])
model_group.add_argument("--allowed-media-domains",
**model_kwargs["allowed_media_domains"])
model_group.add_argument("--revision", **model_kwargs["revision"]) model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision", model_group.add_argument("--code-revision",
**model_kwargs["code_revision"]) **model_kwargs["code_revision"])
@ -997,6 +1001,7 @@ class EngineArgs:
tokenizer_mode=self.tokenizer_mode, tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path, allowed_local_media_path=self.allowed_local_media_path,
allowed_media_domains=self.allowed_media_domains,
dtype=self.dtype, dtype=self.dtype,
seed=self.seed, seed=self.seed,
revision=self.revision, revision=self.revision,
@ -1481,7 +1486,7 @@ class EngineArgs:
raise NotImplementedError( raise NotImplementedError(
"Draft model speculative decoding is not supported yet. " "Draft model speculative decoding is not supported yet. "
"Please consider using other speculative decoding methods " "Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.") "such as ngram, medusa, eagle, or mtp.")
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN", "FLASH_ATTN",

View File

@ -11,7 +11,12 @@ from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union, from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast) cast)
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
# yapf: enable # yapf: enable
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid, supports_kw
logger = init_logger(__name__) logger = init_logger(__name__)
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
@property
def allowed_media_domains(self):
return self._model_config.allowed_media_domains
@property @property
def mm_registry(self): def mm_registry(self):
return MULTIMODAL_REGISTRY return MULTIMODAL_REGISTRY
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
) )
def parse_image( def parse_image(
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=media_io_kwargs, media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
) )
def parse_image( def parse_image(
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
    # We exclude chat_template from kwargs here, because the chat template
    # has already been resolved at this stage
    unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
def apply_hf_chat_template( def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
) )
try: try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template, chat_template=hf_chat_template,
tokenize=tokenize, tokenize=tokenize,
**kwargs, **resolved_kwargs,
) )
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
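
A minimal sketch of the filtering idea behind `resolve_chat_template_kwargs`, using an invented template and kwargs; the real helper additionally consults `supports_kw` on `tokenizer.apply_chat_template` and registers the `AssistantTracker`/loopcontrols extensions:

```python
import jinja2
import jinja2.meta
import jinja2.sandbox

chat_template = (
    "{% for m in messages %}{{ m['content'] }}{% endfor %}"
    "{% if enable_thinking %}<think>{% endif %}")
chat_template_kwargs = {
    "enable_thinking": True,       # referenced by the template -> kept
    "chat_template": "ignored",    # already resolved upstream -> dropped
    "unknown_flag": 1,             # not referenced anywhere -> dropped
}

env = jinja2.sandbox.ImmutableSandboxedEnvironment()
template_vars = jinja2.meta.find_undeclared_variables(env.parse(chat_template))
accept_vars = template_vars - {"chat_template"}

resolved = {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
print(resolved)  # {'enable_thinking': True}
```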

View File

@ -86,6 +86,8 @@ class LLM:
or videos from directories specified by the server file system. or videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted This is a security risk. Should only be enabled in trusted
environments. environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
tensor_parallel_size: The number of GPUs to use for distributed tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism. execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently, dtype: The data type for the model weights and activations. Currently,
@ -169,6 +171,7 @@ class LLM:
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
allowed_media_domains: Optional[list[str]] = None,
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: ModelDType = "auto", dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None, quantization: Optional[QuantizationMethods] = None,
@ -264,6 +267,7 @@ class LLM:
skip_tokenizer_init=skip_tokenizer_init, skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
allowed_local_media_path=allowed_local_media_path, allowed_local_media_path=allowed_local_media_path,
allowed_media_domains=allowed_media_domains,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
dtype=dtype, dtype=dtype,
quantization=quantization, quantization=quantization,
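
A hedged usage sketch of the new `allowed_media_domains` argument; the model name and domain below are placeholders, not recommendations:

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",        # placeholder model
    allowed_media_domains=["upload.wikimedia.org"],  # placeholder allow-list
)
# Requests referencing media hosted on any other domain are rejected during
# multimodal input parsing.
```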

View File

@ -3,12 +3,14 @@
import asyncio import asyncio
import gc import gc
import hashlib
import importlib import importlib
import inspect import inspect
import json import json
import multiprocessing import multiprocessing
import multiprocessing.forkserver as forkserver import multiprocessing.forkserver as forkserver
import os import os
import secrets
import signal import signal
import socket import socket
import tempfile import tempfile
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
class AuthenticationMiddleware: class AuthenticationMiddleware:
""" """
     Pure ASGI middleware that authenticates each request by checking
-    if the Authorization header exists and equals "Bearer {api_key}".
+    if the Authorization Bearer token exists and matches any of the
+    configured API keys.
Notes Notes
----- -----
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
def __init__(self, app: ASGIApp, tokens: list[str]) -> None: def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
self.app = app self.app = app
self.api_tokens = {f"Bearer {token}" for token in tokens} self.api_tokens = [
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
]
def verify_token(self, headers: Headers) -> bool:
authorization_header_value = headers.get("Authorization")
if not authorization_header_value:
return False
scheme, _, param = authorization_header_value.partition(" ")
if scheme.lower() != "bearer":
return False
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
token_match = False
for token_hash in self.api_tokens:
token_match |= secrets.compare_digest(param_hash, token_hash)
return token_match
def __call__(self, scope: Scope, receive: Receive, def __call__(self, scope: Scope, receive: Receive,
send: Send) -> Awaitable[None]: send: Send) -> Awaitable[None]:
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
url_path = URL(scope=scope).path.removeprefix(root_path) url_path = URL(scope=scope).path.removeprefix(root_path)
headers = Headers(scope=scope) headers = Headers(scope=scope)
# Type narrow to satisfy mypy. # Type narrow to satisfy mypy.
if url_path.startswith("/v1") and headers.get( if url_path.startswith("/v1") and not self.verify_token(headers):
"Authorization") not in self.api_tokens:
response = JSONResponse(content={"error": "Unauthorized"}, response = JSONResponse(content={"error": "Unauthorized"},
status_code=401) status_code=401)
return response(scope, receive, send) return response(scope, receive, send)
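
A self-contained sketch of the hardened check: tokens are stored as SHA-256 digests and compared with `secrets.compare_digest`, and every stored token is compared on every request rather than returning early. Header access is simplified to a plain dict here:

```python
import hashlib
import secrets

api_tokens = [hashlib.sha256(t.encode("utf-8")).digest()
              for t in ("token-abc", "token-xyz")]


def verify_token(headers: dict[str, str]) -> bool:
    authorization = headers.get("Authorization")
    if not authorization:
        return False
    scheme, _, param = authorization.partition(" ")
    if scheme.lower() != "bearer":
        return False
    param_hash = hashlib.sha256(param.encode("utf-8")).digest()
    # OR the results together instead of returning early, so the comparison
    # work does not depend on which (if any) token matched.
    token_match = False
    for token_hash in api_tokens:
        token_match |= secrets.compare_digest(param_hash, token_hash)
    return token_match


assert verify_token({"Authorization": "Bearer token-abc"})
assert not verify_token({"Authorization": "Bearer wrong"})
assert not verify_token({"Authorization": "Basic token-abc"})
```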
@ -1696,6 +1716,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args. exclude_tools_when_tool_choice_none=args.

View File

@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto" chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template. """The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"` * "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI * "openai" will render the content as a list of dictionaries, similar to
schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
the server will always use the chat template specified by `--chat-template`
or the ones from the tokenizer."""
response_role: str = "assistant" response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`.""" """The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None ssl_keyfile: Optional[str] = None

View File

@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
request_logger: Optional[RequestLogger], request_logger: Optional[RequestLogger],
chat_template: Optional[str], chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False, return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "", reasoning_parser: str = "",
enable_auto_tools: bool = False, enable_auto_tools: bool = False,
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role self.response_role = response_role
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
# set up tool use # set up tool use
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
if not self.use_harmony: if not self.use_harmony:
# Common case. # Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
( (
conversation, conversation,
request_prompts, request_prompts,
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
chat_template=request.chat_template or self.chat_template, chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self. chat_template_content_format=self.
chat_template_content_format, chat_template_content_format,
add_generation_prompt=request.add_generation_prompt, add_generation_prompt=request.add_generation_prompt,

View File

@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
@ -725,6 +726,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": "VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
    # Whether to allow HTTP redirects when fetching from media URLs.
    # Defaults to True
"VLLM_MEDIA_URL_ALLOW_REDIRECTS":
lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))),
# Max number of workers for the thread pool handling # Max number of workers for the thread pool handling
# media bytes loading. Set to 1 to disable parallel processing. # media bytes loading. Set to 1 to disable parallel processing.
# Default is 8 # Default is 8

View File

@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
return BatchDescriptor(self.num_tokens, uniform_decode=False) return BatchDescriptor(self.num_tokens, uniform_decode=False)
-def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
+def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
+                           sequence_parallel_size: int) -> list[int]:
+    sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
+                 sequence_parallel_size)
+    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
+    return sp_tokens.tolist()
+
+
+def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
+                                      sequence_parallel_size: int,
                                       max_num_tokens: int,
                                       chunk_idx: int) -> list[int]:
-    dp_size = len(num_tokens_across_dp_cpu)
-    local_size = [-1] * dp_size
-    for i in range(dp_size):
-        dp_tokens = num_tokens_across_dp_cpu[i]
+    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
+                                       sequence_parallel_size)
+    sp_size = len(sp_tokens)
+
+    local_size = [-1] * sp_size
+    for i in range(sp_size):
+        # Take into account sharding if MoE activation is sequence parallel.
         local_size[i] = min(max_num_tokens,
-                            dp_tokens - (max_num_tokens * chunk_idx))
+                            sp_tokens[i] - (max_num_tokens * chunk_idx))
         if local_size[i] <= 0:
             local_size[i] = 1  # ensure lockstep even if done
     return local_size
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
 @dataclass
 class DPMetadata:
     max_tokens_across_dp_cpu: torch.Tensor
-    cu_tokens_across_dp_cpu: torch.Tensor
+    num_tokens_across_dp_cpu: torch.Tensor
+
+    # NOTE: local_sizes should only be set by the chunked_sizes context manager
     local_sizes: Optional[list[int]] = None
@staticmethod @staticmethod
@ -98,6 +113,17 @@ class DPMetadata:
dist.all_reduce(num_tokens_tensor, group=group) dist.all_reduce(num_tokens_tensor, group=group)
return num_tokens_tensor.cpu() return num_tokens_tensor.cpu()
    # Get the cumulative tokens across sequence parallel ranks.
    # In this case the input to the MoEs will be distributed w.r.t. both
    # the DP and TP rank.
    # When sp_size == 1, this is just the cumulative num tokens across DP.
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        num_tokens_across_sp_cpu = (
            (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
        num_tokens_across_sp_cpu = (
            num_tokens_across_sp_cpu.repeat_interleave(sp_size))
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
@staticmethod @staticmethod
def should_ubatch_across_dp( def should_ubatch_across_dp(
should_ubatch: bool, orig_num_tokens_per_ubatch: int, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
@ -147,10 +173,10 @@ class DPMetadata:
@staticmethod @staticmethod
def make( def make(
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
attn_metadata: Any, attn_metadata: Any,
num_tokens: int, num_tokens: int,
-        num_tokens_across_dp: Optional[torch.Tensor] = None
+        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
) -> "DPMetadata": ) -> "DPMetadata":
assert parallel_config.data_parallel_size > 1 assert parallel_config.data_parallel_size > 1
@ -167,18 +193,18 @@ class DPMetadata:
        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
-        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
-                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
-        if num_tokens_across_dp is None:
-            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+        assert (num_tokens_across_dp_cpu is None
+                or num_tokens_across_dp_cpu[dp_rank] == batchsize
+                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
+        if num_tokens_across_dp_cpu is None:
+            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                 batchsize, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
-        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
-        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
-                          num_tokens_across_dp)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
+        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
@contextmanager @contextmanager
-    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
+    def chunked_sizes(self, sequence_parallel_size: int,
+                      max_chunk_size_per_rank: int, chunk_idx: int):
""" """
Context manager to compute and temporarily set the per-rank local token Context manager to compute and temporarily set the per-rank local token
sizes for a specific chunk during chunked forward execution. sizes for a specific chunk during chunked forward execution.
@ -192,31 +218,40 @@ class DPMetadata:
`chunk_idx`, this context manager sets `self.local_sizes` to the number `chunk_idx`, this context manager sets `self.local_sizes` to the number
of tokens to process in that chunk on each rank. of tokens to process in that chunk on each rank.
-        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
-        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
-        to determine the chunk-wise split.
         `self.local_sizes` is only valid inside the context.
Args: Args:
sequence_parallel_size: When Attn is TP and MoE layers are EP,
we use SP between the layers to avoid
redundant ops. We need this value to
compute the chunked sizes.
max_chunk_size_per_rank: The max number of tokens each rank is max_chunk_size_per_rank: The max number of tokens each rank is
allowed to process in this chunk. allowed to process in this chunk.
chunk_idx: The index of the chunk to compute sizes for. chunk_idx: The index of the chunk to compute sizes for.
""" """
-        cu_sizes = self.cu_tokens_across_dp_cpu
-        num_tokens_across_dp_cpu = [
-            (cu_sizes[i] -
-             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
-            for i in range(len(cu_sizes))
-        ]
         self.local_sizes = _compute_chunked_local_num_tokens(
-            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
+            self.num_tokens_across_dp_cpu, sequence_parallel_size,
+            max_chunk_size_per_rank, chunk_idx)
+        try:
+            yield self.local_sizes
+        finally:
+            self.local_sizes = None
+
+    @contextmanager
+    def sp_local_sizes(self, sequence_parallel_size: int):
+        """
+        Context manager for setting self.local_sizes. Same as
+        self.chunked_sizes but without any chunking.
+        """
+        self.local_sizes = _compute_sp_num_tokens(
+            self.num_tokens_across_dp_cpu, sequence_parallel_size)
         try:
             yield self.local_sizes
         finally:
             self.local_sizes = None
     def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
+        assert self.local_sizes is not None
         return self.local_sizes
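
A worked numeric example of the sequence-parallel token accounting, re-computed standalone on a toy tensor (2 DP ranks, `sp_size=2`):

```python
import torch

num_tokens_across_dp_cpu = torch.tensor([7, 4])  # tokens per DP rank
sp_size = 2

# Ceil-divide each DP rank's tokens across its sequence-parallel shards,
# then repeat so there is one entry per (dp_rank, sp_rank) pair.
sp_tokens = (num_tokens_across_dp_cpu + sp_size - 1) // sp_size
sp_tokens = sp_tokens.repeat_interleave(sp_size)
print(sp_tokens.tolist())                        # [4, 4, 2, 2]

# Cumulative form, as used when sizing the naive dispatch/combine buffers.
print(torch.cumsum(sp_tokens, dim=0).tolist())   # [4, 8, 10, 12]
```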

View File

@ -3,6 +3,7 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from contextlib import nullcontext
from enum import Enum from enum import Enum
from typing import Callable, Literal, Optional, Union, get_args, overload from typing import Callable, Literal, Optional, Union, get_args, overload
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
if dp_size is not None else get_dp_group().world_size) if dp_size is not None else get_dp_group().world_size)
self.is_sequence_parallel = is_sequence_parallel self.is_sequence_parallel = is_sequence_parallel
-        if self.is_sequence_parallel:
-            self.sp_size = tp_size_
+        self.sp_size = tp_size_ if is_sequence_parallel else 1
self.moe_parallel_config: FusedMoEParallelConfig = ( self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make( FusedMoEParallelConfig.make(
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
# clamp start and end # clamp start and end
chunk_start = min(chunk_start, num_tokens - 1) chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens) chunk_end = min(chunk_end, num_tokens)
-            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
+            with ctx.dp_metadata.chunked_sizes(self.sp_size,
+                                               moe_dp_chunk_size_per_rank,
                                                chunk_idx):
process_chunk(chunk_start, process_chunk(chunk_start,
chunk_end, chunk_end,
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
else: else:
shared_output = None shared_output = None
if do_naive_dispatch_combine: ctx = get_forward_context()
hidden_states, router_logits = get_ep_group().dispatch( sp_ctx = ctx.dp_metadata.sp_local_sizes(
hidden_states, router_logits) self.sp_size) if ctx.dp_metadata else nullcontext()
# Matrix multiply. with sp_ctx:
final_hidden_states = self.quant_method.apply( if do_naive_dispatch_combine:
layer=self, hidden_states, router_logits = get_ep_group().dispatch(
x=hidden_states, hidden_states, router_logits, self.is_sequence_parallel)
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
)
if shared_output is not None: # Matrix multiply.
assert not isinstance(final_hidden_states, tuple) final_hidden_states = self.quant_method.apply(
assert self.shared_experts is not None layer=self,
final_hidden_states = ( x=hidden_states,
shared_output, router_logits=router_logits,
final_hidden_states, top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
global_num_experts=self.global_num_experts,
expert_map=self.expert_map,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
enable_eplb=self.enable_eplb,
expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count,
) )
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
def reduce_output(states: torch.Tensor, if shared_output is not None:
do_combine: bool = True) -> torch.Tensor: assert not isinstance(final_hidden_states, tuple)
if do_naive_dispatch_combine and do_combine: assert self.shared_experts is not None
states = get_ep_group().combine(states) final_hidden_states = (
shared_output,
final_hidden_states,
)
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
assert isinstance(final_hidden_states, tuple)
final_hidden_states, zero_expert_result = final_hidden_states
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): def reduce_output(states: torch.Tensor,
states = self.maybe_all_reduce_tensor_model_parallel(states) do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states,
self.is_sequence_parallel)
return states if (not self.is_sequence_parallel and self.reduce_results
and (self.tp_size > 1 or self.ep_size > 1)):
states = self.maybe_all_reduce_tensor_model_parallel(
states)
if self.shared_experts is not None: return states
return (
reduce_output(final_hidden_states[0], do_combine=False), if self.shared_experts is not None:
reduce_output(final_hidden_states[1]), return (
) reduce_output(final_hidden_states[0], do_combine=False),
elif self.zero_expert_num is not None and self.zero_expert_num > 0: reduce_output(final_hidden_states[1]),
assert isinstance(final_hidden_states, torch.Tensor) )
return reduce_output(final_hidden_states) + zero_expert_result elif self.zero_expert_num is not None and self.zero_expert_num > 0:
else: assert isinstance(final_hidden_states, torch.Tensor)
return reduce_output(final_hidden_states) return reduce_output(final_hidden_states) + zero_expert_result
else:
return reduce_output(final_hidden_states)
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(

View File

@ -5,6 +5,7 @@ from typing import Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
@ -375,3 +376,20 @@ class PolyNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon) return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
class LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
self.eps).type_as(x)
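
A brief usage sketch for the `LayerNorm` added above (it assumes the class definition from this diff is in scope): parameters stay in float32, normalization runs in float32, and the result is cast back to the input dtype:

```python
import torch

# Assumes the LayerNorm class defined in this diff is importable/in scope.
ln = LayerNorm(dim=128, eps=1e-6)
x = torch.randn(2, 16, 128, dtype=torch.bfloat16)
y = ln(x)
assert y.dtype == torch.bfloat16 and y.shape == x.shape
```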

View File

@ -24,6 +24,9 @@ class MLAModules:
q_a_layernorm: Optional[torch.nn.Module] q_a_layernorm: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module] q_b_proj: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module] q_proj: Optional[torch.nn.Module]
indexer: Optional[torch.nn.Module]
is_sparse: bool
topk_indices_buffer: Optional[torch.Tensor]
@CustomOp.register("multi_head_latent_attention") @CustomOp.register("multi_head_latent_attention")
@ -76,6 +79,13 @@ class MultiHeadLatentAttention(CustomOp):
self.kv_b_proj = mla_modules.kv_b_proj self.kv_b_proj = mla_modules.kv_b_proj
self.rotary_emb = mla_modules.rotary_emb self.rotary_emb = mla_modules.rotary_emb
self.o_proj = mla_modules.o_proj self.o_proj = mla_modules.o_proj
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
if self.indexer is not None:
assert hasattr(self.indexer, "topk_tokens")
self.topk_tokens = self.indexer.topk_tokens
self.topk_indices_buffer = mla_modules.topk_indices_buffer
# In the MLA backend, kv_cache includes both k_c and # In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular, # pe (i.e. decoupled position embeddings). In particular,
@ -92,6 +102,7 @@ class MultiHeadLatentAttention(CustomOp):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_mla=True, use_mla=True,
use_sparse=mla_modules.is_sparse,
# MLA Args # MLA Args
q_lora_rank=self.q_lora_rank, q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank, kv_lora_rank=self.kv_lora_rank,
@ -100,6 +111,7 @@ class MultiHeadLatentAttention(CustomOp):
qk_head_dim=self.qk_head_dim, qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim, v_head_dim=self.v_head_dim,
kv_b_proj=self.kv_b_proj, kv_b_proj=self.kv_b_proj,
indexer=self.indexer,
) )
self.prefix = prefix self.prefix = prefix
@ -145,6 +157,10 @@ class MultiHeadLatentAttention(CustomOp):
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe) positions, q[..., self.qk_nope_head_dim:], k_pe)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions,
self.rotary_emb)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,

View File

@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
# requantize the weight and input to the specific scale # requantize the weight and input to the specific scale
# at the same time. # at the same time.
-    if is_deep_gemm_e8m0_used():
+    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
+        layer.orig_dtype, layer.weight)
+    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
         block_sz = tuple(layer.weight_block_size)
         requant_weight_ue8m0_inplace(layer.weight.data,
                                      layer.weight_scale.data, block_sz)
     # SM90 Block FP8 CUTLASS requires row-major weight scales
     elif (current_platform.is_device_capability(90)
-          and cutlass_block_fp8_supported
-          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
-                                                     layer.weight)):
+          and cutlass_block_fp8_supported and not should_use_deepgemm):
layer.weight_scale = torch.nn.Parameter( layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data.T.contiguous(), requires_grad=False) layer.weight_scale.data.T.contiguous(), requires_grad=False)

View File

@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor from transformers.models.aria.processing_aria import AriaProcessor
-from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
+from vllm.config import QuantizationConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@ -298,14 +298,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
Experts (MoE) Layer. Experts (MoE) Layer.
""" """
-    def __init__(
-        self,
-        config: AriaTextConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config, cache_config, quant_config, prefix)
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super().__init__(vllm_config, prefix)
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
self.mlp = AriaTextMoELayer(config, self.mlp = AriaTextMoELayer(config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.mlp") prefix=f"{prefix}.mlp")

View File

@ -346,8 +346,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
-dtype=kv_cache_dtype,
-use_mla=model_config.use_mla).page_size_bytes
+dtype=kv_cache_dtype).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
@ -401,6 +400,31 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
"exactly equal.", mamba_padding_pct)
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
"""
hf_config = vllm_config.model_config.hf_config
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
is_v32 = hasattr(hf_config, "index_topk")
assert is_v32
# For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
# "auto")
cache_config = vllm_config.cache_config
if cache_config.cache_dtype == "auto" or \
cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
@ -417,4 +441,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig,
"FalconMambaForCausalLM": MambaModelConfig,
+"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
}

View File

@ -53,8 +53,20 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
self.shared_head = SharedHead(config=config, quant_config=quant_config)
-self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
+self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
+topk_indices_buffer)
def forward(
self,

View File

@ -32,17 +32,22 @@ import torch
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
-import vllm.envs as envs
from vllm.attention import Attention
+from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, ParallelConfig, VllmConfig
+from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
+get_current_vllm_config)
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
@ -50,20 +55,35 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+per_token_group_quant_fp8)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
+from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
+DeepseekV32IndexerMetadata)
+from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
+if current_platform.is_cuda_alike():
+from vllm import _custom_ops as ops
+elif current_platform.is_xpu():
+from vllm._ipex_ops import ipex_ops as ops
+logger = init_logger(__name__)
class DeepseekV2MLP(nn.Module):
@ -108,43 +128,6 @@ class DeepseekV2MLP(nn.Module):
return x
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
x = nn.functional.pad(x, (0, 0, 0, pad_len))
chunk = x.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(x, 0, start, chunk)
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk",
op_func=sequence_parallel_chunk,
fake_impl=sequence_parallel_chunk_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
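The helper removed here (now imported from vllm.model_executor.models.utils) pads the token dimension up to a multiple of tp_size and returns this rank's slice. Below is a standalone sketch of that pad-and-narrow behavior for a fixed rank and world size, using plain torch with no distributed state; it is an illustration, not part of the diff.

import torch
import torch.nn.functional as F

def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Pad num_tokens up to a multiple of tp_size, then take this rank's slice,
    # mirroring the pad + torch.narrow logic of sequence_parallel_chunk above.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return x.narrow(0, tp_rank * chunk, chunk)

x = torch.randn(10, 8)                        # 10 tokens, hidden size 8
parts = [chunk_for_rank(x, r, 4) for r in range(4)]
assert all(p.shape == (3, 8) for p in parts)  # 10 padded to 12, 3 tokens per rank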
class DeepseekV2MoE(nn.Module):

def __init__(
@ -166,20 +149,7 @@ class DeepseekV2MoE(nn.Module):
self.n_routed_experts: int = config.n_routed_experts
self.n_shared_experts: int = config.n_shared_experts
-# The all_reduce at the end of attention (during o_proj) means that
-# inputs are replicated across each rank of the tensor parallel group.
-# If using expert-parallelism with DeepEP All2All ops, replicated
-# tokens results in useless duplicate computation and communication.
-#
-# In this case, ensure the input to the experts is sequence parallel
-# to avoid the excess work.
-#
-# Not needed for pplx-kernels as it can handle duplicate input tokens.
-self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
-in ("deepep_high_throughput",
-"deepep_low_latency")
-and parallel_config.enable_expert_parallel
-and self.tp_size > 1)
+self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if config.hidden_act != "silu":
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
@ -278,8 +248,7 @@ class DeepseekV2MoE(nn.Module):
# TODO: We can replace the all_reduce at the end of attn with a
# reduce_scatter instead of chunking here.
if self.is_sequence_parallel:
-hidden_states = torch.ops.vllm.sequence_parallel_chunk(
-hidden_states)
+hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
@ -328,6 +297,7 @@ class DeepseekV2Attention(nn.Module):
def __init__(
self,
+vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
@ -341,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
+topk_indices_buffer: Optional[torch.Tensor] = None,
prefix: str = "",
) -> None:
super().__init__()
@ -358,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
+assert topk_indices_buffer is None, "topk_indices_buffer is not \
+supported for DeepseekV2Attention"
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
@ -470,6 +443,390 @@ class DeepseekV2Attention(nn.Module):
return output
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
cache_config: CacheConfig):
super().__init__()
self.kv_cache = [torch.tensor([])]
self.head_dim = head_dim
self.prefix = prefix
self.cache_config = cache_config
self.dtype = dtype
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def get_kv_cache_spec(self) -> KVCacheSpec:
return MLAAttentionSpec( # Only has one vector instead of K + V
block_size=self.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_dim,
dtype=self.dtype,
)
def forward(self):
...
def get_attn_backend(self) -> AttentionBackend:
return DeepseekV32IndexerBackend
@torch.inference_mode()
def cp_gather_indexer_k_quant_cache(
kv_cache, # [num_blocks, block_size, head_dim + 1]
dst_value, # [cu_seq_lens[-1], head_dim]
dst_scale, # [cu_seq_lens[-1], 4]
block_table, # [batch_size, num_blocks]
cu_seq_lens, # [batch_size + 1, ]
batch_size,
):
num_blocks, block_size, _ = kv_cache.shape
head_dim = dst_value.shape[-1]
kv_cache = kv_cache.view(num_blocks, -1)
expected_value = []
expected_scale = []
for b in range(batch_size):
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
if s == 0:
continue
tot = cdiv(s, block_size)
blocks = block_table[b, :tot]
value = []
scale = []
full_block = torch.arange(tot - 1,
device=kv_cache.device,
dtype=torch.int32)
non_remaining_value = kv_cache[blocks[full_block], :block_size *
head_dim].view(-1, head_dim)
non_remaining_scale = kv_cache[blocks[full_block],
block_size * head_dim:].view(-1, 4)
remaining = s - (tot - 1) * block_size
value = torch.cat([
non_remaining_value,
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
],
dim=0)
scale = torch.cat([
non_remaining_scale,
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
remaining * 4].view(-1, 4)
],
dim=0)
expected_value.append(value)
expected_scale.append(scale)
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
gather_value = gather_value.view(torch.float8_e4m3fn)
gather_scale = gather_scale.view(torch.float32)
dst_value.copy_(gather_value)
dst_scale.copy_(gather_scale)
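cp_gather_indexer_k_quant_cache assumes each cached token row packs head_dim fp8 bytes followed by a 4-byte float32 scale into one uint8 buffer (head_dim + 4 bytes when quant_block_size equals head_dim, as in the Indexer below). The following is a minimal packing/unpacking sketch with plain torch views, assuming one scale per token and a torch build that provides float8_e4m3fn; it illustrates the byte layout only and is not part of the diff.

import torch

head_dim = 128
k = torch.randn(16, head_dim)                                  # 16 tokens
scale = (k.abs().amax(dim=-1, keepdim=True) / 448.0).clamp(min=1e-12)
k_fp8 = (k / scale).to(torch.float8_e4m3fn)

row = torch.empty(16, head_dim + 4, dtype=torch.uint8)
row[:, :head_dim] = k_fp8.view(torch.uint8)                    # fp8 payload, 1 byte each
row[:, head_dim:] = scale.to(torch.float32).view(torch.uint8)  # 4-byte fp32 scale

# Unpack the same way the gather above does: reinterpret the byte ranges.
k_back = row[:, :head_dim].view(torch.float8_e4m3fn).to(torch.float32)
s_back = row[:, head_dim:].contiguous().view(torch.float32)
recovered = k_back * s_back                                    # dequantized keys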
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
k,
weights,
quant_block_size,
scale_fmt,
topk_tokens,
head_dim,
max_model_len,
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device,
dtype=torch.float8_e4m3fn)
k_scale = torch.empty([chunk.total_seq_lens, 1],
device=k.device,
dtype=torch.float32)
cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
chunk.num_reqs,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start:chunk.token_end],
(k_fp8, k_scale),
weights[chunk.token_start:chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
)
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
dim=-1)[1]
topk_indices -= chunk.cu_seqlen_ks[:, None]
mask_lo = topk_indices >= 0
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
chunk.cu_seqlen_ks)[:, None] < 0
mask = torch.full_like(topk_indices,
False,
dtype=torch.bool,
device=topk_indices.device)
mask = mask_lo & mask_hi
topk_indices = topk_indices.masked_fill(~mask, -1)
topk_indices_buffer[
chunk.token_start:chunk.token_end, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
if has_decode:
decode_metadata = attn_metadata.decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:])
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
# padded query len
current_device = padded_q_fp8_decode_tokens.device
padded_num_tokens = batch_size * next_n
positions = torch.arange(max_model_len,
device=current_device).unsqueeze(0).expand(
batch_size * next_n, -1)
row_indices = torch.arange(padded_num_tokens,
device=current_device) // next_n
next_n_offset = torch.arange(
padded_num_tokens,
device=padded_q_fp8_decode_tokens.device) % next_n
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
next_n_offset).unsqueeze(1)
# index_end_pos: [B * N, 1]
mask = positions <= index_end_pos
# mask: [B * N, L]
logits = logits.masked_fill(~mask, float('-inf'))
topk_indices = logits.topk(topk_tokens,
dim=-1)[1].to(torch.int32) # [B * N, K]
# ensure we don't set indices for the top k
# that is out of range(masked already)
# this will happen if context length is shorter than K
topk_indices[topk_indices > index_end_pos] = -1
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
topk_indices = unpack_seq_triton(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens)
topk_indices_buffer[:num_decode_tokens, :topk_indices.
shape[-1]] = topk_indices.to(dtype=torch.int32)
return topk_indices_buffer
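The decode branch above builds a positions-versus-index_end_pos mask so that top-k never selects cache slots beyond each query's allowed context, then writes -1 for any pick that still lands out of range. A small self-contained sketch of that masking pattern on dummy logits (illustration only, not part of the diff):

import torch

max_len, topk = 12, 4
logits = torch.randn(3, max_len)                       # 3 decode queries
index_end_pos = torch.tensor([[5], [9], [2]])          # last valid position per query

positions = torch.arange(max_len).unsqueeze(0).expand(3, -1)
mask = positions <= index_end_pos                      # [3, max_len]
masked = logits.masked_fill(~mask, float("-inf"))

topk_indices = masked.topk(topk, dim=-1).indices.to(torch.int32)
# If a query has fewer than topk valid positions, the extra picks land on
# masked (-inf) slots; mark them invalid the same way the kernel path does.
topk_indices[topk_indices > index_end_pos] = -1
print(topk_indices)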
def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
scale_fmt: Optional[str],
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
# profile run
# NOTE(Chen): create the max possible flattened_kv. So that
# profile_run can get correct memory usage.
_flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
device=k.device,
dtype=torch.uint8)
_k_fp8 = _flattened_kv[..., :head_dim].view(
torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer
direct_register_custom_op(
op_name="sparse_attn_indexer",
op_func=sparse_attn_indexer,
mutates_args=["topk_indices_buffer"],
fake_impl=sparse_attn_indexer_fake,
dispatch_key=current_platform.dispatch_key,
)
class Indexer(nn.Module):
def __init__(self,
vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
q_lora_rank: int,
quant_config: Optional[QuantizationConfig],
cache_config: Optional[CacheConfig],
topk_indices_buffer: Optional[torch.Tensor],
prefix: str = ""):
super().__init__()
self.vllm_config = vllm_config
self.config = config
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64
self.head_dim = config.index_head_dim # 128
self.rope_dim = config.qk_rope_head_dim # 64
self.q_lora_rank = q_lora_rank # 1536
# no tensor parallel, just replicated
self.wq_b = ReplicatedLinear(self.q_lora_rank,
self.head_dim * self.n_head,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b")
self.wk = ReplicatedLinear(hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk")
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.weights_proj = ReplicatedLinear(hidden_size,
self.n_head,
quant_config=None,
prefix=f"{prefix}.weights_proj")
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"
self.quant_block_size = 128 # TODO: get from config
self.topk_indices_buffer = topk_indices_buffer
# NOTE: (zyongye) we use fp8 naive cache,
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim +
self.head_dim // self.quant_block_size * 4,
dtype=torch.uint8,
prefix=f"{prefix}.k_cache",
cache_config=cache_config)
self.max_model_len = vllm_config.model_config.max_model_len
self.prefix = prefix
from vllm.v1.attention.backends.mla.indexer import (
get_max_prefill_buffer_size)
self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
rotary_emb) -> torch.Tensor:
q, _ = self.wq_b(qr)
q = q.view(-1, self.n_head, self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
k, _ = self.wk(hidden_states)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
q = torch.cat([q_pe, q_nope], dim=-1)
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt
is not None)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = weights.unsqueeze(
-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights = weights.squeeze(-1)
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
q_fp8,
k,
weights,
self.quant_block_size,
self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
)
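The query quantization in Indexer.forward relies on per_token_group_quant_fp8 with a group size of 128: each contiguous group of 128 elements shares one scale. Below is an illustrative reimplementation of that idea only (not vLLM's kernel, and without the ue8m0 scale rounding used above); it assumes a torch build with float8_e4m3fn and is not part of the diff.

import torch

def group_quant_fp8(x: torch.Tensor, group_size: int = 128):
    # x: [num_tokens, dim] with dim divisible by group_size.
    num_tokens, dim = x.shape
    g = x.view(num_tokens, dim // group_size, group_size)
    # One scale per group, sized so the group fits in the e4m3 range (~448).
    scale = (g.abs().amax(dim=-1, keepdim=True) / 448.0).clamp(min=1e-12)
    q = (g / scale).to(torch.float8_e4m3fn)
    return q.view(num_tokens, dim), scale.squeeze(-1)   # fp8 values, per-group scales

x = torch.randn(4, 128)
q, s = group_quant_fp8(x)
assert q.shape == (4, 128) and s.shape == (4, 1)        # 128 elems -> 1 group per token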
class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
@ -481,6 +838,7 @@ class DeepseekV2MLAAttention(nn.Module):
def __init__(
self,
+vllm_config: VllmConfig,
config: Union[DeepseekV2Config, DeepseekV3Config],
hidden_size: int,
num_heads: int,
@ -495,6 +853,7 @@ class DeepseekV2MLAAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
+topk_indices_buffer: Optional[torch.Tensor] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -575,6 +934,15 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
self.indexer = Indexer(vllm_config, config, hidden_size,
q_lora_rank, quant_config, cache_config,
topk_indices_buffer, f"{prefix}.indexer")
else:
self.indexer = None
mla_modules = MLAModules(
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
@ -588,7 +956,11 @@ class DeepseekV2MLAAttention(nn.Module):
if self.q_lora_rank is not None else None,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else None,
indexer=self.indexer,
is_sparse=self.is_v32,
topk_indices_buffer=topk_indices_buffer,
)
self.mla_attn = MultiHeadLatentAttention(
self.hidden_size,
self.num_local_heads,
@ -614,7 +986,10 @@ class DeepseekV2MLAAttention(nn.Module):
class DeepseekV2DecoderLayer(nn.Module):

-def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
+def __init__(self,
+vllm_config: VllmConfig,
+prefix: str,
+topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
@ -637,6 +1012,7 @@ class DeepseekV2DecoderLayer(nn.Module):
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
+vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -652,6 +1028,7 @@ class DeepseekV2DecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
+topk_indices_buffer=topk_indices_buffer,
)
if (config.n_routed_experts is not None
@ -735,6 +1112,16 @@ class DeepseekV2Model(nn.Module):
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device="cuda")
else:
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
@ -747,7 +1134,8 @@ class DeepseekV2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
-lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
+lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
+topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:

View File

@ -29,10 +29,9 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
+from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
def __init__(
self,
-config: PretrainedConfig,
+vllm_config: VllmConfig,
prefix: str,
-model_config: ModelConfig,
-cache_config: Optional[CacheConfig] = None,
-quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
+config = vllm_config.model_config.hf_config
self.mtp_emb_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=False)
-self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
-prefix)
+self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
def forward(
self,
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleDict({
str(idx):
ErnieMultiTokenPredictorLayer(
-config,
+vllm_config,
f"{prefix}.layers.{idx}",
-model_config=vllm_config.model_config,
-cache_config=vllm_config.cache_config,
)
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)

View File

@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
class Glm4DecoderLayer(nn.Module):

-def __init__(
-self,
-config: Glm4Config,
-cache_config: Optional[CacheConfig] = None,
-quant_config: Optional[QuantizationConfig] = None,
-prefix: str = "",
-) -> None:
+def __init__(self,
+vllm_config: VllmConfig,
+prefix: str = "",
+config: Optional[Glm4Config] = None) -> None:
super().__init__()
+config = config or vllm_config.model_config.hf_config
+cache_config = vllm_config.cache_config
+quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)

View File

@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
-get_tensor_model_parallel_world_size)
+get_tensor_model_parallel_world_size,
+tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
def __init__(
self,
-config: GptOssConfig,
-quant_config: QuantizationConfig,
+vllm_config: VllmConfig,
layer_idx: int,
prefix: str = "",
):
super().__init__()
+config = vllm_config.model_config.hf_config
+quant_config = vllm_config.quant_config
+parallel_config = vllm_config.parallel_config
+self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
self.layer_idx = layer_idx
self.num_experts = config.num_local_experts
self.experts_per_token = config.num_experts_per_tok
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
apply_router_weight_on_input=False, apply_router_weight_on_input=False,
has_bias=True, has_bias=True,
activation="swigluoai") activation="swigluoai",
is_sequence_parallel=self.is_sequence_parallel)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
num_tokens = x.shape[0]
if self.is_sequence_parallel:
x = sequence_parallel_chunk(x)
g = self.router(x) g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g) x = self.experts(hidden_states=x, router_logits=g)
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
x = x[:num_tokens]
return x return x
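The forward path above (the same pattern appears in the GraniteMoe and Llama4 hunks below) chunks tokens across the TP group before the experts and all-gathers afterwards, then trims the padding. A shape-only sketch of that round trip without any distributed setup, purely for illustration and not part of the diff:

import torch

tp_size, num_tokens, hidden = 4, 10, 8
x = torch.randn(num_tokens, hidden)

# chunk: pad to a multiple of tp_size and split one slice per rank
pad = (-num_tokens) % tp_size
x_pad = torch.nn.functional.pad(x, (0, 0, 0, pad))
chunks = list(x_pad.chunk(tp_size, dim=0))            # each rank keeps one chunk

# ... each rank runs its experts on its own chunk ...

# all_gather + trim: concatenating the per-rank outputs restores the padded
# order, and slicing back to num_tokens drops the padding rows.
gathered = torch.cat(chunks, dim=0)
out = gathered[:num_tokens]
assert out.shape == x.shape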
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
def __init__(
self,
-config: GptOssConfig,
-cache_config: CacheConfig,
-quant_config: QuantizationConfig,
+vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
+config = vllm_config.model_config.hf_config
+cache_config = vllm_config.cache_config
self.layer_idx = extract_layer_index(prefix)
self.attn = OAIAttention(config,
prefix=f"{prefix}.attn",
cache_config=cache_config)
-self.mlp = MLPBlock(config,
+self.mlp = MLPBlock(vllm_config,
self.layer_idx,
-quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
):
super().__init__()
self.config = vllm_config.model_config.hf_config
-self.cache_config = vllm_config.cache_config
-self.quant_config = vllm_config.quant_config
self.parallel_config = vllm_config.parallel_config
self.config.hidden_size = self.config.hidden_size
self.embedding = VocabParallelEmbedding(
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
self.config.num_hidden_layers,
lambda prefix: TransformerBlock(
-self.config,
-cache_config=self.cache_config,
-quant_config=self.quant_config,
+vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",

View File

@ -29,12 +29,13 @@ from typing import Any, Optional
import torch
from torch import nn
-from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+get_tensor_model_parallel_world_size,
+tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
+is_sequence_parallel=False,
prefix: str = ""):
super().__init__()
self.hidden_size = hidden_size
+self.is_sequence_parallel = is_sequence_parallel
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
-prefix=f"{prefix}.experts")
+prefix=f"{prefix}.experts",
+is_sequence_parallel=self.is_sequence_parallel)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
+if self.is_sequence_parallel:
+hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
+if self.is_sequence_parallel:
+final_hidden_states = tensor_model_parallel_all_gather(
+final_hidden_states, 0)
+num_tokens = orig_shape[0]
+final_hidden_states = final_hidden_states[:num_tokens]
return final_hidden_states.view(orig_shape)
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
def __init__(
self,
-config: GraniteMoeConfig,
-cache_config: Optional[CacheConfig] = None,
-quant_config: Optional[QuantizationConfig] = None,
+vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
+config = vllm_config.model_config.hf_config
+cache_config = vllm_config.cache_config
+quant_config = vllm_config.quant_config
+parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
+is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
prefix=f"{prefix}.block_sparse_moe")
self.input_layernorm = RMSNorm(config.hidden_size,
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
-cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
-lambda prefix: GraniteMoeDecoderLayer(
-config, cache_config, quant_config=quant_config, prefix=prefix
-),
+lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

View File

@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
bias: bool = False,
prefix: str = "",
reduce_results: bool = True,
+disable_tp: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
+disable_tp=disable_tp,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
bias=bias,
quant_config=quant_config,
reduce_results=reduce_results,
+disable_tp=disable_tp,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):

-def __init__(
-self,
-config: LlamaConfig,
-cache_config: Optional[CacheConfig] = None,
-quant_config: Optional[QuantizationConfig] = None,
-prefix: str = "",
-) -> None:
+def __init__(self,
+vllm_config: VllmConfig,
+prefix: str = "",
+config: Optional[LlamaConfig] = None) -> None:
super().__init__()
+config = config or vllm_config.model_config.hf_config
+cache_config = vllm_config.cache_config
+quant_config = vllm_config.quant_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
-cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
-lambda prefix: layer_type(config=config,
-cache_config=cache_config,
-quant_config=quant_config,
-prefix=prefix),
+lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:

View File

@ -28,7 +28,8 @@ from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_world_size,
+tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk

from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
router_scores = torch.sigmoid(router_scores.float())
return (router_scores, router_indices.to(torch.int32))

-def __init__(self,
-config: Llama4TextConfig,
-quant_config: Optional[QuantizationConfig] = None,
-prefix: str = ""):
+def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
+config = vllm_config.model_config.hf_config
+parallel_config = vllm_config.parallel_config
+quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
+self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
intermediate_size_moe = config.intermediate_size
self.router = ReplicatedLinear(config.hidden_size,
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
bias=False,
prefix=f"{prefix}.shared_expert",
reduce_results=False,
+disable_tp=self.is_sequence_parallel,
)
self.experts = SharedFusedMoE(
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
renormalize=False,
quant_config=quant_config,
prefix=f"{prefix}.experts",
+is_sequence_parallel=self.is_sequence_parallel,
)

def forward(self, hidden_states):
+num_tokens = hidden_states.shape[0]
+if self.is_sequence_parallel:
+hidden_states = sequence_parallel_chunk(hidden_states)
router_logits, _ = self.router(hidden_states)
shared_out, routed_out = self.experts(
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
)
experts_out = routed_out + shared_out
-if self.tp_size > 1:
+if self.is_sequence_parallel:
+experts_out = tensor_model_parallel_all_gather(experts_out, 0)
+experts_out = experts_out[:num_tokens]
+elif self.tp_size > 1:
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
experts_out)
@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):
class Llama4DecoderLayer(nn.Module):

-def __init__(
-self,
-config: Llama4TextConfig,
-cache_config: Optional[CacheConfig] = None,
-quant_config: Optional[QuantizationConfig] = None,
-prefix: str = "",
-) -> None:
+def __init__(self,
+vllm_config: VllmConfig,
+prefix: str = "",
+config: Optional[Llama4TextConfig] = None) -> None:
super().__init__()
+config = config or vllm_config.model_config.hf_config
+cache_config = vllm_config.cache_config
+quant_config = vllm_config.quant_config
self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
self.layer_idx + 1) % config.interleave_moe_layer_step == 0
if is_moe_layer:
self.feed_forward = Llama4MoE(
-config=config,
-quant_config=quant_config,
+vllm_config=vllm_config,
prefix=f"{prefix}.feed_forward",
)
else:

View File

@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
Llama4DecoderLayer(
-self.config,
-quant_config=quant_config,
+vllm_config=vllm_config,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+config=self.config,
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
def __init__(
self,
-config: LlamaConfig,
+vllm_config: VllmConfig,
disable_input_layernorm: bool,
prefix: str = "",
+config: Optional[LlamaConfig] = None,
) -> None:
-super().__init__(config, prefix=prefix)
+super().__init__(vllm_config, prefix=prefix, config=config)
# Skip the input_layernorm
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
-self.config,
+vllm_config,
i == 0,
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+config=self.config,
) for i in range(self.config.num_hidden_layers)
])
self.fc = torch.nn.Linear(self.config.hidden_size * 2,

View File

@ -9,13 +9,11 @@ import torch.nn as nn
from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -29,17 +27,14 @@ logger = init_logger(__name__)
class LlamaDecoderLayer(LlamaDecoderLayer):

-def __init__(
-self,
-config: LlamaConfig,
-cache_config: Optional[CacheConfig] = None,
-quant_config: Optional[QuantizationConfig] = None,
-prefix: str = "",
-) -> None:
-super().__init__(config,
-cache_config=cache_config,
-quant_config=quant_config,
-prefix=prefix)
+def __init__(self,
+vllm_config: VllmConfig,
+prefix: str = "",
+config: Optional[LlamaConfig] = None) -> None:
+super().__init__(vllm_config, prefix=prefix, config=config)
+config = config or vllm_config.model_config.hf_config
+quant_config = vllm_config.quant_config
# override qkv
self.self_attn.qkv_proj = QKVParallelLinear(
@ -127,9 +122,9 @@ class LlamaModel(nn.Module):
self.layers = nn.ModuleList([
LlamaDecoderLayer(
-config=self.config,
-cache_config=current_vllm_config.cache_config,
+current_vllm_config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
+config=self.config,
)
])
if hasattr(self.config, "target_hidden_size"):

View File

@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module):
def __init__(
self,
+vllm_config: VllmConfig,
config: FlashConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@ -329,6 +330,7 @@ class FlashDecoderLayer(nn.Module):
# Dual attention structure
self.self_attn = nn.ModuleList([
DeepseekV2MLAAttention(
+vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -454,6 +456,7 @@ class FlashModel(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: FlashDecoderLayer(
+vllm_config,
config,
cache_config=cache_config,
quant_config=quant_config,

View File

@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
+attn_backend: _Backend = _Backend.TORCH_SDPA,
+use_upstream_fa: bool = False,
) -> None:
super().__init__()
# Per attention head and per partition values.
@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
-# Detect attention implementation.
-self.attn_backend = get_vit_attn_backend(
-head_size=self.hidden_size_per_attention_head,
-dtype=torch.get_default_dtype())
-self.use_upstream_fa = False
-if self.attn_backend != _Backend.FLASH_ATTN and \
-check_upstream_fa_availability(
-torch.get_default_dtype()):
-self.attn_backend = _Backend.FLASH_ATTN
-self.use_upstream_fa = True
-if self.attn_backend not in {
-_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
-_Backend.ROCM_AITER_FA
-}:
-raise RuntimeError(
-f"Qwen2.5-VL does not support {self.attn_backend} backend now."
-)
+self.attn_backend = attn_backend
+self.use_upstream_fa = use_upstream_fa
self.is_flash_attn_backend = self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
}
@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
use_data_parallel: bool = False, use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
use_upstream_fa: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
if norm_layer is None: if norm_layer is None:
@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
projection_size=dim, projection_size=dim,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel) use_data_parallel=use_data_parallel,
attn_backend=attn_backend,
use_upstream_fa=use_upstream_fa)
self.mlp = Qwen2_5_VisionMLP(dim, self.mlp = Qwen2_5_VisionMLP(dim,
mlp_hidden_dim, mlp_hidden_dim,
act_fn=act_fn, act_fn=act_fn,
@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
head_dim = self.hidden_size // self.num_heads head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
use_upstream_fa = False
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
}:
raise RuntimeError(
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
)
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
Qwen2_5_VisionBlock(dim=self.hidden_size, Qwen2_5_VisionBlock(
num_heads=self.num_heads, dim=self.hidden_size,
mlp_hidden_dim=vision_config.intermediate_size, num_heads=self.num_heads,
act_fn=get_act_and_mul_fn( mlp_hidden_dim=vision_config.intermediate_size,
vision_config.hidden_act), act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer, norm_layer=norm_layer,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}", prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel) use_data_parallel=use_data_parallel,
for layer_idx in range(depth) attn_backend=self.attn_backend,
use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
]) ])
self.merger = Qwen2_5_VisionPatchMerger( self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size, d_model=vision_config.out_hidden_size,
@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel, use_data_parallel=use_data_parallel,
) )
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype:

View File
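In the Qwen2.5-VL diff above, the ViT attention backend is now detected once in Qwen2_5_VisionTransformer.__init__ and passed down into every block and attention module, instead of being re-detected inside each attention layer. A small sketch of that "resolve once, inject into submodules" shape, using illustrative stand-ins rather than vLLM's real _Backend enum and detection helpers:

# Sketch of resolving a backend once and injecting it into submodules, rather
# than re-detecting it in every block. Enum values and detect_backend() are
# illustrative stand-ins only.
import enum


class Backend(enum.Enum):
    TORCH_SDPA = "torch_sdpa"
    FLASH_ATTN = "flash_attn"


def detect_backend() -> Backend:
    # Stand-in for get_vit_attn_backend() plus the upstream-FA availability check.
    return Backend.TORCH_SDPA


class VisionBlock:
    def __init__(self, attn_backend: Backend, use_upstream_fa: bool) -> None:
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa


class VisionTransformer:
    def __init__(self, depth: int) -> None:
        attn_backend = detect_backend()  # resolved exactly once
        use_upstream_fa = False
        self.blocks = [
            VisionBlock(attn_backend, use_upstream_fa) for _ in range(depth)
        ]


vit = VisionTransformer(depth=2)
print([block.attn_backend for block in vit.blocks])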

@@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
 logger = init_logger(__name__)
 # For profile run
-_MAX_FRAMES_PER_VIDEO = 32
+_MAX_FRAMES_PER_VIDEO = 14
 # === Vision Inputs === #
@@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         _, num_image_tokens = self._get_vision_info(
             image_width=image_width,
             image_height=image_height,
+            num_frames=1,
             image_processor=image_processor,
         )
         return num_image_tokens
@@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         max_image_size, _ = self._get_vision_info(
             image_width=9999999,
             image_height=9999999,
+            num_frames=1,
             image_processor=None,
         )
         return max_image_size
@@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
             image_processor=None,
         )
-    def _get_max_video_frames(self, max_tokens: int) -> int:
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 1) -> int:
         target_width, target_height = self.get_image_size_with_most_features()
-        num_frames = 0
+        num_frames = start_num_frames
         while True:
             next_num_frames = num_frames + 1
@@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
+        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
     ) -> int:
         max_videos = mm_counts.get("video", 0)
         max_total_frames = self._get_max_video_frames(seq_len)
         max_frames_per_video = min(max_total_frames // max(max_videos, 1),
-                                   _MAX_FRAMES_PER_VIDEO)
+                                   max_frames_per_video)
         return max(max_frames_per_video, 1)

View File
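The new max_frames_per_video parameter above lets a subclass raise the per-video cap (the Qwen3-VL diff below passes a much larger value) while the result stays bounded by the total frame budget. A tiny worked example of that arithmetic, with made-up numbers:

# Worked example of the frame-budget arithmetic above; all numbers are made up.
def frames_with_most_features(max_total_frames: int, max_videos: int,
                              max_frames_per_video: int = 14) -> int:
    per_video = min(max_total_frames // max(max_videos, 1),
                    max_frames_per_video)
    return max(per_video, 1)


print(frames_with_most_features(max_total_frames=100, max_videos=2))  # 14, capped
print(frames_with_most_features(max_total_frames=10, max_videos=2))   # 5, budget-bound
print(frames_with_most_features(max_total_frames=0, max_videos=1))    # 1, floor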

@@ -29,13 +29,13 @@ from typing import Any, Optional, Union
 import torch
 from torch import nn
-from transformers import Qwen3MoeConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed import (get_ep_group, get_pp_group,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.sequence import IntermediateTensors
 from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
-        config: Qwen3MoeConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: VllmConfig,
         prefix: str = "",
-        enable_eplb: bool = False,
     ):
         super().__init__()
+        config = vllm_config.model_config.hf_text_config
+        parallel_config = vllm_config.parallel_config
+        quant_config = vllm_config.quant_config
+
         self.tp_size = get_tensor_model_parallel_world_size()
         self.ep_group = get_ep_group().device_group
@@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts = config.num_experts
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
         eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        self.enable_eplb = parallel_config.enable_eplb
         self.n_logical_experts = self.n_routed_experts
         self.n_redundant_experts = eplb_config.num_redundant_experts
@@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel)
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         assert hidden_states.dim(
         ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
         is_input_1d = hidden_states.dim() == 1
-        hidden_dim = hidden_states.shape[-1]
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+
         # return to 1d if input is 1d
         return final_hidden_states.squeeze(0) if is_input_1d else \
             final_hidden_states
@@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):
 class Qwen3MoeDecoderLayer(nn.Module):
-    def __init__(
-        self,
-        config: Qwen3MoeConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ) -> None:
+    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         super().__init__()
+        config = vllm_config.model_config.hf_text_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
@@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp",
-                                              enable_eplb=enable_eplb)
+            self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
+                                              prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -361,11 +372,9 @@ class Qwen3MoeModel(nn.Module):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config.get_text_config()
-        cache_config = vllm_config.cache_config
+        config = vllm_config.model_config.hf_text_config
         quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
-        enable_eplb = parallel_config.enable_eplb
         eplb_config = parallel_config.eplb_config
         self.num_redundant_experts = eplb_config.num_redundant_experts
@@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
             prefix=f"{prefix}.embed_tokens")
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Qwen3MoeDecoderLayer(config=config,
-                                                cache_config=cache_config,
-                                                quant_config=quant_config,
-                                                prefix=prefix,
-                                                enable_eplb=enable_eplb),
+            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
+                                                prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -580,7 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.model_config.hf_text_config
         quant_config = vllm_config.quant_config
         self.config = config
         self.quant_config = quant_config

View File

@@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
                          VllmConfig, get_current_vllm_config)
 from vllm.distributed import (divide, get_ep_group, get_pp_group,
                               get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fla.ops import (
@@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
+from vllm.model_executor.models.utils import sequence_parallel_chunk
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
 class Qwen3NextSparseMoeBlock(nn.Module):
-    def __init__(
-        self,
-        config: Qwen3NextConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        enable_eplb: bool = False,
-    ):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        parallel_config = vllm_config.parallel_config
+        quant_config = vllm_config.quant_config
+
         self.tp_size = get_tensor_model_parallel_world_size()
         self.ep_group = get_ep_group().device_group
@@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         self.ep_size = self.ep_group.size()
         self.n_routed_experts = config.num_experts
+        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         # Load balancing settings.
         vllm_config = get_current_vllm_config()
         eplb_config = vllm_config.parallel_config.eplb_config
-        self.enable_eplb = enable_eplb
+        self.enable_eplb = parallel_config.enable_eplb
         self.n_logical_experts = self.n_routed_experts
         self.n_redundant_experts = eplb_config.num_redundant_experts
@@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
             enable_eplb=self.enable_eplb,
-            num_redundant_experts=self.n_redundant_experts)
+            num_redundant_experts=self.n_redundant_experts,
+            is_sequence_parallel=self.is_sequence_parallel)
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
-        hidden_dim = hidden_states.shape[-1]
+        num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
+        if self.is_sequence_parallel:
+            hidden_states = sequence_parallel_chunk(hidden_states)
+
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        if self.tp_size > 1:
+        if self.is_sequence_parallel:
+            final_hidden_states = tensor_model_parallel_all_gather(
+                final_hidden_states, 0)
+            final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
             final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                 final_hidden_states)
@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
     def __init__(
         self,
-        config: Qwen3NextConfig,
+        vllm_config: VllmConfig,
         layer_type: str,
-        model_config: Optional[ModelConfig] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        speculative_config: Optional[SpeculativeConfig] = None,
         prefix: str = "",
-        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
-        self.config = config
+        config = vllm_config.model_config.hf_config
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        speculative_config = vllm_config.speculative_config
+
         self.layer_type = layer_type
         self.layer_idx = extract_layer_index(prefix)
@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
                 config.num_experts > 0 and
                 (self.layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3NextSparseMoeBlock(
-                config=config,
-                quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=f"{prefix}.mlp",
-                enable_eplb=enable_eplb,
             )
         else:
             self.mlp = Qwen3NextMLP(
@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ), )
         self.ffn_layer_scale = torch.nn.Parameter(
             torch.zeros(
                 1,
                 1,
-                self.config.hidden_size,
+                config.hidden_size,
                 dtype=config.torch_dtype,
             ), )
@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
         super().__init__()
         config: Qwen3NextConfig = vllm_config.model_config.hf_config
-        model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
         lora_config = vllm_config.lora_config
-        speculative_config = vllm_config.speculative_config
-        enable_eplb = parallel_config.enable_eplb
         eplb_config = parallel_config.eplb_config
         self.num_redundant_experts = eplb_config.num_redundant_experts
@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
         def get_layer(prefix: str):
             return Qwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type=config.layer_types[extract_layer_index(prefix)],
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                speculative_config=speculative_config,
                 prefix=prefix,
-                enable_eplb=enable_eplb,
             )
         self.start_layer, self.end_layer, self.layers = make_layers(

View File

@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
         super().__init__()
         model_config = vllm_config.model_config
-        cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
         config: Qwen3NextConfig = model_config.hf_config
@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
         self.layers = torch.nn.ModuleList(
             Qwen3NextDecoderLayer(
-                config,
+                vllm_config,
                 layer_type="full_attention",
-                model_config=model_config,
-                cache_config=cache_config,
-                quant_config=quant_config,
                 prefix=f'{prefix}.layers.{idx}',
             ) for idx in range(self.num_mtp_layers))

View File

@@ -33,11 +33,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import BatchFeature
 from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
-from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
+    smart_resize as image_smart_resize)
 from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                           Qwen3VLVideoProcessor)
 from transformers.models.qwen3_vl.configuration_qwen3_vl import (
     Qwen3VLConfig, Qwen3VLVisionConfig)
+from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
+    smart_resize as video_smart_resize)
 from transformers.video_utils import VideoMetadata
 from vllm.attention.layer import check_upstream_fa_availability
@@ -84,6 +87,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
 logger = init_logger(__name__)
+# Official recommended max pixels is 24576 * 32 * 32
+_MAX_FRAMES_PER_VIDEO = 24576
 class Qwen3_VisionPatchEmbed(nn.Module):
@@ -158,6 +164,8 @@ class Qwen3_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        use_upstream_fa: bool = False,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -170,7 +178,9 @@ class Qwen3_VisionBlock(nn.Module):
             projection_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            use_data_parallel=use_data_parallel)
+            use_data_parallel=use_data_parallel,
+            attn_backend=attn_backend,
+            use_upstream_fa=use_upstream_fa)
         self.mlp = Qwen3_VisionMLP(dim,
                                    mlp_hidden_dim,
                                    act_fn=act_fn,
@@ -287,19 +297,6 @@ class Qwen3_VisionTransformer(nn.Module):
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
-        self.blocks = nn.ModuleList([
-            Qwen3_VisionBlock(
-                dim=self.hidden_size,
-                num_heads=self.num_heads,
-                mlp_hidden_dim=vision_config.intermediate_size,
-                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                norm_layer=norm_layer,
-                quant_config=quant_config,
-                prefix=f"{prefix}.blocks.{layer_idx}",
-                use_data_parallel=use_data_parallel)
-            for layer_idx in range(vision_config.depth)
-        ])
         self.merger = Qwen3_VisionPatchMerger(
             d_model=vision_config.out_hidden_size,
             context_dim=self.hidden_size,
@@ -325,10 +322,34 @@ class Qwen3_VisionTransformer(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim, dtype=torch.get_default_dtype())
+        use_upstream_fa = False
         if self.attn_backend != _Backend.FLASH_ATTN and \
                 check_upstream_fa_availability(
                 torch.get_default_dtype()):
             self.attn_backend = _Backend.FLASH_ATTN
+            use_upstream_fa = True
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
+                _Backend.ROCM_AITER_FA
+        }:
+            raise RuntimeError(
+                f"Qwen3-VL does not support {self.attn_backend} backend now.")
+        self.blocks = nn.ModuleList([
+            Qwen3_VisionBlock(
+                dim=self.hidden_size,
+                num_heads=self.num_heads,
+                mlp_hidden_dim=vision_config.intermediate_size,
+                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                norm_layer=norm_layer,
+                quant_config=quant_config,
+                prefix=f"{prefix}.blocks.{layer_idx}",
+                use_data_parallel=use_data_parallel,
+                attn_backend=self.attn_backend,
+                use_upstream_fa=use_upstream_fa)
+            for layer_idx in range(vision_config.depth)
+        ])
     @property
     def dtype(self) -> torch.dtype:
@@ -569,11 +590,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         image_height: int,
         num_frames: int = 2,
         do_resize: bool = True,
-        image_processor: Optional[Qwen2VLImageProcessorFast],
+        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
+                                        Qwen3VLVideoProcessor]],
     ) -> tuple[ImageSize, int]:
-        if image_processor is None:
+        if image_processor is None and num_frames > 1:
+            image_processor = self.get_video_processor()
+        elif image_processor is None:
             image_processor = self.get_image_processor()
+        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
         hf_config = self.get_hf_config()
         vision_config = hf_config.vision_config
         patch_size = vision_config.patch_size
@@ -581,12 +607,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         temporal_patch_size = vision_config.temporal_patch_size
         if do_resize:
+            if is_video:
+                smart_resize = video_smart_resize
+                extra_kwargs = {
+                    "num_frames": num_frames,
+                    "temporal_factor": temporal_patch_size
+                }
+            else:
+                smart_resize = image_smart_resize
+                extra_kwargs = {}
             resized_height, resized_width = smart_resize(
                 height=image_height,
                 width=image_width,
                 factor=patch_size * merge_size,
                 min_pixels=image_processor.size["shortest_edge"],
                 max_pixels=image_processor.size["longest_edge"],
+                **extra_kwargs,
             )
             preprocessed_size = ImageSize(width=resized_width,
                                           height=resized_height)
@@ -605,6 +641,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
         return preprocessed_size, num_vision_tokens
+    def _get_max_video_frames(self,
+                              max_tokens: int,
+                              start_num_frames: int = 2) -> int:
+        return super()._get_max_video_frames(max_tokens,
+                                             start_num_frames=start_num_frames)
+
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        return super().get_num_frames_with_most_features(
+            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)
+
+    def get_max_video_tokens(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+        video_soft_tokens = self.get_num_video_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=self.get_num_frames_with_most_features(
+                seq_len, mm_counts),
+            image_processor=None,
+        )
+        # NOTE: By default in Qwen3-VL, one video token is converted to
+        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
+        formatted_video_soft_tokens = video_soft_tokens * 12.5
+        return int(formatted_video_soft_tokens)
+
     def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                               video_fps: float, merge_size: int):
         if not isinstance(indices, list):
@@ -674,6 +743,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
             self.info.get_image_size_with_most_features())
         target_num_frames = self.info.get_num_frames_with_most_features(
             seq_len, mm_counts)
+        target_video_size, _ = self.info._get_vision_info(
+            image_width=target_width,
+            image_height=target_height,
+            num_frames=target_num_frames,
+            image_processor=self.info.get_video_processor(),
+        )
         return {
             "image":
             self._get_dummy_images(width=target_width,
@@ -681,8 +756,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
                                    num_images=num_images),
             "video":
             self._get_dummy_videos(
-                width=target_width,
-                height=target_height,
+                width=target_video_size.width,
+                height=target_video_size.height,
                 num_frames=target_num_frames,
                 num_videos=num_videos,
             ),
@@ -1051,14 +1126,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.config = config
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
-        self.visual = Qwen3_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-            use_data_parallel=self.use_data_parallel,
-        )
+        if not multimodal_config.get_limit_per_prompt("image") and \
+            not multimodal_config.get_limit_per_prompt("video"):
+            self.visual = None
+        else:
+            self.visual = Qwen3_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
+            )
         self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
                                                   prefix=maybe_prefix(
@@ -1074,11 +1152,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             config.vision_config.deepstack_visual_indexes
         ) if self.use_deepstack else 0
         # register buffer for deepstack
-        self.deepstack_input_embeds = [
-            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
-                        config.text_config.hidden_size)
-            for _ in range(self.deepstack_num_level)
-        ] if self.use_deepstack else None
+        if self.use_deepstack and self.visual is not None:
+            self.deepstack_input_embeds = [
+                torch.zeros(
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    config.text_config.hidden_size)
+                for _ in range(self.deepstack_num_level)
+            ]
+        else:
+            self.deepstack_input_embeds = None
         self.visual_dim = config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level
@@ -1513,7 +1595,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
     def get_mm_mapping(self) -> MultiModelKeys:

View File
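The NOTE in get_max_video_tokens above budgets roughly 12.5 prompt tokens per video soft token: about 9.5 tokens on average for the "<{timestamp} seconds>" text plus the vision_start, video, and vision_end special tokens. A small sketch of that accounting; the soft-token count used in the example is made up for illustration.

# Rough token-budget arithmetic from the NOTE above: each video soft token is
# wrapped as "<{timestamp} seconds>" (~9.5 text tokens on average) plus the
# vision_start/video/vision_end special tokens, i.e. ~12.5 tokens in total.
AVG_TIMESTAMP_TOKENS = 9.5
SPECIAL_TOKENS = 3  # vision_start_token + video_token + vision_end_token
TOKENS_PER_VIDEO_SOFT_TOKEN = AVG_TIMESTAMP_TOKENS + SPECIAL_TOKENS  # 12.5


def max_video_tokens(video_soft_tokens: int) -> int:
    return int(video_soft_tokens * TOKENS_PER_VIDEO_SOFT_TOKEN)


print(max_video_tokens(1000))  # 12500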

@@ -212,6 +212,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                     # attempted to load as other weights later
                     is_expert_weight = True
                     name_mapped = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name_mapped, self):
+                        continue
                     if is_fused_expert:
                         loaded_weight = loaded_weight.transpose(-1,
                                                                 -2)  # no bias
@@ -230,8 +232,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                                 name_mapped, params_dict, loaded_weight,
                                 shard_id, num_experts)
                     else:
-                        if is_pp_missing_parameter(name_mapped, self):
-                            continue
                         # Skip loading extra parameters for GPTQ/modelopt models
                         if name_mapped.endswith(
                                 ignore_suffixes
@@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
-        self.visual = Qwen3_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-            use_data_parallel=self.use_data_parallel,
-        )
+        if not multimodal_config.get_limit_per_prompt("image") and \
+            not multimodal_config.get_limit_per_prompt("video"):
+            self.visual = None
+        else:
+            self.visual = Qwen3_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
+            )
         self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
                                                      prefix=maybe_prefix(
@@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
             config.vision_config.deepstack_visual_indexes
         ) if self.use_deepstack else 0
         # register buffer for deepstack
-        self.deepstack_input_embeds = [
-            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
-                        config.text_config.hidden_size)
-            for _ in range(self.deepstack_num_level)
-        ] if self.use_deepstack else None
+        if self.use_deepstack and self.visual is not None:
+            self.deepstack_input_embeds = [
+                torch.zeros(
+                    vllm_config.scheduler_config.max_num_batched_tokens,
+                    config.text_config.hidden_size)
+                for _ in range(self.deepstack_num_level)
+            ]
+        else:
+            self.deepstack_input_embeds = None
         self.visual_dim = config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level

View File

@@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = {
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
     "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
+    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
     "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
     "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
     "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),

View File

@@ -13,11 +13,14 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.config import VllmConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import NestedTensors
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
+from vllm.utils import (cdiv, direct_register_custom_op,
+                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                         is_uva_available)
 logger = init_logger(__name__)
@@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
         return hf_config.hidden_size
     text_config = hf_config.get_text_config()
     return text_config.hidden_size
+
+
+# Chunk x along the num_tokens axis for sequence parallelism
+# NOTE: This is wrapped in a torch custom op to work around the following issue:
+# The output tensor can have a sequence length 0 at small input sequence lengths
+# even though we explicitly pad to avoid this.
+def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
+    return torch.ops.vllm.sequence_parallel_chunk_impl(x)
+
+
+def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # all_gather needs the sequence length to be divisible by tp_size
+    seq_len = x.size(0)
+    remainder = seq_len % tp_size
+    if remainder != 0:
+        pad_len = tp_size - remainder
+        y = nn.functional.pad(x, (0, 0, 0, pad_len))
+    else:
+        y = x
+
+    chunk = y.shape[0] // tp_size
+    start = tp_rank * chunk
+    return torch.narrow(y, 0, start, chunk)
+
+
+def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
+    tp_size = get_tensor_model_parallel_world_size()
+    seq_len = cdiv(x.size(0), tp_size)
+    shape = list(x.shape)
+    shape[0] = seq_len
+    out = torch.empty(shape, dtype=x.dtype, device=x.device)
+    return out
+
+
+direct_register_custom_op(
+    op_name="sequence_parallel_chunk_impl",
+    op_func=sequence_parallel_chunk_impl,
+    fake_impl=sequence_parallel_chunk_impl_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)

View File
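sequence_parallel_chunk above pads the token dimension up to a multiple of the tensor-parallel world size and returns this rank's slice; the MoE blocks in the Qwen3-MoE and Qwen3-Next diffs then all-gather the expert outputs and trim back to num_tokens. A single-process sketch of that round trip, simulating the ranks in a loop (torch.cat stands in for tensor_model_parallel_all_gather; no distributed setup is involved):

# Single-process sketch of the chunk -> all_gather -> trim round trip performed
# by the sequence-parallel MoE path. torch.cat stands in for the real
# tensor_model_parallel_all_gather collective.
import torch
import torch.nn.functional as F


def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Pad num_tokens up to a multiple of tp_size, then take this rank's slice.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return x.narrow(0, tp_rank * chunk, chunk)


tp_size, num_tokens, hidden = 4, 10, 8
hidden_states = torch.randn(num_tokens, hidden)

# Each simulated "rank" processes its own chunk (identity compute here), then
# the chunks are gathered along dim 0 and the padding rows are dropped.
chunks = [chunk_for_rank(hidden_states, r, tp_size) for r in range(tp_size)]
gathered = torch.cat(chunks, dim=0)[:num_tokens]

assert torch.equal(gathered, hidden_states)
print(gathered.shape)  # torch.Size([10, 8])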

@@ -50,6 +50,7 @@ class MediaConnector:
         connection: HTTPConnection = global_http_connection,
         *,
         allowed_local_media_path: str = "",
+        allowed_media_domains: Optional[list[str]] = None,
     ) -> None:
         """
         Args:
@@ -82,6 +83,9 @@ class MediaConnector:
             allowed_local_media_path_ = None
         self.allowed_local_media_path = allowed_local_media_path_
+        if allowed_media_domains is None:
+            allowed_media_domains = []
+        self.allowed_media_domains = allowed_media_domains
     def _load_data_url(
         self,
@@ -115,6 +119,14 @@ class MediaConnector:
         return media_io.load_file(filepath)
+    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
+        if self.allowed_media_domains and url_spec.hostname not in \
+            self.allowed_media_domains:
+            raise ValueError(
+                f"The URL must be from one of the allowed domains: "
+                f"{self.allowed_media_domains}. Input URL domain: "
+                f"{url_spec.hostname}")
+
     def load_from_url(
         self,
         url: str,
@@ -125,8 +137,14 @@ class MediaConnector:
         url_spec = urlparse(url)
         if url_spec.scheme.startswith("http"):
+            self._assert_url_in_allowed_media_domains(url_spec)
+
             connection = self.connection
-            data = connection.get_bytes(url, timeout=fetch_timeout)
+            data = connection.get_bytes(
+                url,
+                timeout=fetch_timeout,
+                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
+            )
             return media_io.load_bytes(data)
@@ -150,8 +168,14 @@ class MediaConnector:
         loop = asyncio.get_running_loop()
         if url_spec.scheme.startswith("http"):
+            self._assert_url_in_allowed_media_domains(url_spec)
+
             connection = self.connection
-            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
+            data = await connection.async_get_bytes(
+                url,
+                timeout=fetch_timeout,
+                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
+            )
             future = loop.run_in_executor(global_thread_pool,
                                           media_io.load_bytes, data)
             return await future

View File
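The _assert_url_in_allowed_media_domains check added above rejects any http(s) URL whose hostname is not in the configured allowlist. A standalone sketch that mirrors that check outside MediaConnector so it can run on its own; the example domains are made up.

# Standalone sketch mirroring the media-domain allowlist check added above;
# reimplemented here so it runs on its own, outside MediaConnector.
from urllib.parse import urlparse


def assert_url_in_allowed_media_domains(url: str,
                                        allowed_media_domains: list[str]) -> None:
    url_spec = urlparse(url)
    if allowed_media_domains and url_spec.hostname not in allowed_media_domains:
        raise ValueError(
            f"The URL must be from one of the allowed domains: "
            f"{allowed_media_domains}. Input URL domain: {url_spec.hostname}")


allowed = ["example.com"]
assert_url_in_allowed_media_domains("https://example.com/cat.png", allowed)  # passes
try:
    assert_url_in_allowed_media_domains("https://evil.test/cat.png", allowed)
except ValueError as e:
    print(e)  # rejected: hostname not in the allowlist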

@@ -93,11 +93,14 @@ class CpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool) -> str:
+                             has_sink: bool, use_sparse: bool) -> str:
         if selected_backend and selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
         if use_mla:
             raise NotImplementedError("MLA is not supported on CPU.")
+        if use_sparse:
+            raise NotImplementedError(
+                "Sparse Attention is not supported on CPU.")
         logger.info("Using Torch SDPA backend.")
         if not use_v1:
             raise ValueError("CPU backend only supports V1.")

View File

@@ -129,6 +129,8 @@ class CudaPlatformBase(Platform):
         # TODO(lucas): handle this more gracefully
         # Note: model_config may be None during testing
         if model_config is not None and model_config.use_mla:
+            use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                 "index_topk")
             # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
             # then we default to FlashMLA backend for non-blackwell GPUs,
             # else we default to CutlassMLA. For each case, we force the
@@ -175,6 +177,12 @@ class CudaPlatformBase(Platform):
                         "Forcing kv cache block size to 64 for FlashInferMLA "
                         "backend.")
+            # TODO(Chen): remove this hacky code
+            if use_sparse and cache_config.block_size != 64:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashMLASparse "
+                    "backend.")
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
@@ -205,6 +213,12 @@ class CudaPlatformBase(Platform):
     @classmethod
     def get_vit_attn_backend(cls, head_size: int,
                              dtype: torch.dtype) -> _Backend:
+        # For Blackwell GPUs, force TORCH_SDPA for now.
+        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
+        if cls.has_device_capability(100):
+            return _Backend.TORCH_SDPA
+
         if dtype not in (torch.float16, torch.bfloat16):
             return _Backend.XFORMERS
@@ -225,7 +239,7 @@ class CudaPlatformBase(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
-                             has_sink) -> str:
+                             has_sink, use_sparse) -> str:
         if use_mla:
             if not use_v1:
                 raise RuntimeError(
@@ -235,6 +249,11 @@ class CudaPlatformBase(Platform):
             from vllm.attention.ops.flashmla import is_flashmla_supported
             from vllm.attention.utils.fa_utils import flash_attn_supports_mla
+            if use_sparse:
+                logger.info_once("Using Sparse MLA backend on V1 engine.")
+                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
+                        "FlashMLASparseBackend")
+
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)

View File

@@ -194,7 +194,7 @@ class Platform:
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool) -> str:
+                             has_sink: bool, use_sparse: bool) -> str:
         """Get the attention backend class of a device."""
         return ""

View File

@@ -195,7 +195,10 @@ class RocmPlatform(Platform):
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1, use_mla,
-                             has_sink) -> str:
+                             has_sink, use_sparse) -> str:
+        if use_sparse:
+            raise NotImplementedError(
+                "Sparse Attention is not supported on ROCm.")
         if use_mla:
             if not use_v1:
                 raise RuntimeError(

View File

@@ -49,7 +49,10 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink) -> str:
+                             has_sink, use_sparse) -> str:
+        if use_sparse:
+            raise NotImplementedError(
+                "Sparse Attention is not supported on TPU.")
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)

Some files were not shown because too many files have changed in this diff.