Mirror of https://github.com/vllm-project/vllm.git
Synced 2025-11-04 17:34:34 +08:00

Compare commits

33 commits: v0.11.0rc1 ... v0.11.0rc6
| SHA1 |
|---|
| f71952c1c4 |
| d1007767c5 |
| c75c2e70d6 |
| 9d9a2b77f1 |
| 6040e0b6c0 |
| 05bf0c52a1 |
| c536881a7c |
| ebce361c07 |
| e4beabd2c8 |
| febb688356 |
| a1825fe645 |
| bab9231bf1 |
| c214d699fd |
| c3dfb0f6dd |
| 83f3c9beae |
| d0b178cef1 |
| b3230e1ac0 |
| 03df0fb5d2 |
| 9471879bd4 |
| ab5b6459df |
| 8ce5d3198d |
| 09c2cbc04a |
| 4c347044c9 |
| 19e7ab7315 |
| 6de3d431d9 |
| b14773bd64 |
| 26a7a33b88 |
| 5aa5811a16 |
| c2fa2d4dc9 |
| 32335c8b34 |
| 04c2b26972 |
| ee10d7e6ff |
| bb79c4da2f |
@@ -76,7 +76,7 @@ steps:
      queue: arm64_cpu_queue_postmerge
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
      - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"

  # Add job to create multi-arch manifest
@@ -584,8 +584,9 @@ def main(args: argparse.Namespace):
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
    elif config.architectures[0] in (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "DeepseekV3ForCausalLM",
        "DeepseekV32ForCausalLM",
        "Glm4MoeForCausalLM",
    ):
        E = config.n_routed_experts
@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
else()
  FetchContent_Declare(
        flashmla
        GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
        GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
        GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
        GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
        GIT_PROGRESS TRUE
        CONFIGURE_COMMAND ""
        BUILD_COMMAND ""
@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)

set(SUPPORT_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
    list(APPEND SUPPORT_ARCHS 9.0a)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
    list(APPEND SUPPORT_ARCHS 10.0a)
endif()


cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
if(FLASH_MLA_ARCHS)
    set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
    list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")

    set(FlashMLA_SOURCES
        ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
        ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
        ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
        ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
        ${flashmla_SOURCE_DIR}/csrc/pybind.cpp
        ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
        ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
        ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
        ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
        ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
        ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
        ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
        ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
        ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
    )

    set(FlashMLA_Extension_SOURCES
        ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
        ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
        ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
    )

    set(FlashMLA_INCLUDES
        ${flashmla_SOURCE_DIR}/csrc
        ${flashmla_SOURCE_DIR}/csrc/sm90
        ${flashmla_SOURCE_DIR}/csrc/cutlass/include
        ${flashmla_SOURCE_DIR}/csrc)
        ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
    )

    set(FlashMLA_Extension_INCLUDES
        ${flashmla_SOURCE_DIR}/csrc
        ${flashmla_SOURCE_DIR}/csrc/sm90
        ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
        ${flashmla_SOURCE_DIR}/csrc/cutlass/include
        ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
    )

    set_gencode_flags_for_srcs(
        SRCS "${FlashMLA_SOURCES}"
        CUDA_ARCHS "${FLASH_MLA_ARCHS}")

    set_gencode_flags_for_srcs(
        SRCS "${FlashMLA_Extension_SOURCES}"
        CUDA_ARCHS "${FLASH_MLA_ARCHS}")

    define_gpu_extension_target(
        _flashmla_C
        DESTINATION vllm
@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
        INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
        USE_SABI 3
        WITH_SOABI)

    # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
    # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
    target_compile_options(_flashmla_C PRIVATE
        $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
        $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)

    define_gpu_extension_target(
        _flashmla_extension_C
        DESTINATION vllm
        LANGUAGE ${VLLM_GPU_LANG}
        SOURCES ${FlashMLA_Extension_SOURCES}
        COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
        ARCHITECTURES ${VLLM_GPU_ARCHES}
        INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
        USE_SABI 3
        WITH_SOABI)

    # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
    # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
    target_compile_options(_flashmla_extension_C PRIVATE
        $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
        $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
else()
    # Create an empty target for setup.py when not targeting sm90a systems
    # Create empty targets for setup.py when not targeting sm90a systems
    add_custom_target(_flashmla_C)
    add_custom_target(_flashmla_extension_C)
endif()
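A note on the CMake change above: the FlashMLA kernels are now gated per architecture by compiler version before intersecting with the requested `CUDA_ARCHS`. A minimal Python sketch of that gating, assuming `cuda_archs_loose_intersection` behaves like a plain set intersection for these inputs (it is a vLLM CMake helper; the emulation here is illustrative only):

```python
def flashmla_build_archs(cuda_compiler_version: tuple[int, int],
                         requested_archs: set[str]) -> set[str]:
    """Mirror of the CMake gating above: 9.0a needs CUDA > 12.3,
    10.0a needs CUDA > 12.8; build only what was also requested."""
    support_archs = set()
    if cuda_compiler_version > (12, 3):
        support_archs.add("9.0a")
    if cuda_compiler_version > (12, 8):
        support_archs.add("10.0a")
    # cuda_archs_loose_intersection is approximated by a plain intersection here.
    return support_archs & requested_archs


# Example: CUDA 12.9 building for Hopper and Blackwell keeps both arches.
assert flashmla_build_archs((12, 9), {"9.0a", "10.0a"}) == {"9.0a", "10.0a"}
# Example: CUDA 12.4 can only build the sm90a kernels.
assert flashmla_build_archs((12, 4), {"9.0a", "10.0a"}) == {"9.0a"}
```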
@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto problem_shape = params.problem_shape;
-	auto local_split_kv = params.split_kv;
+        auto local_split_kv = params.split_kv;
        if (params.mainloop.ptr_seq != nullptr) {
          get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-	  if (params.ptr_split_kv != nullptr) {
+          if (params.ptr_split_kv != nullptr) {
            local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
          }
        }
-	if (local_split_kv <= get<3>(blk_coord))
-	  continue;
+        if (local_split_kv <= get<3>(blk_coord))
+          continue;
        load_page_table(
          blk_coord,
          problem_shape,
          params.mainloop,
          shared_storage.tensors,
          pipeline_page_table, pipeline_pt_producer_state,
-	  local_split_kv
+          local_split_kv
        );
      }
    }
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
-	  auto problem_shape = params.problem_shape;
-	  auto local_split_kv = params.split_kv;
+          auto problem_shape = params.problem_shape;
+          auto local_split_kv = params.split_kv;
          if (params.mainloop.ptr_seq != nullptr) {
            get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-	    if (params.ptr_split_kv != nullptr) {
+            if (params.ptr_split_kv != nullptr) {
              local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
            }
          }
-	  if (local_split_kv <= get<3>(blk_coord))
+          if (local_split_kv <= get<3>(blk_coord))
            continue;
          load_cpasync(
            blk_coord,
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
            params.mainloop_params,
            shared_storage.tensors,
            pipeline_load_qk, pipeline_load_qk_producer_state,
-	    local_split_kv,
+            local_split_kv,
            /* must be shared pipe */
            pipeline_page_table, pipeline_pt_consumer_state
          );
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
          CUTLASS_PRAGMA_NO_UNROLL
          for (; tile_scheduler.is_valid(); ++tile_scheduler) {
            auto blk_coord = tile_scheduler.get_block_coord();
-	    auto problem_shape = params.problem_shape;
-	    auto local_split_kv = params.split_kv;
+            auto problem_shape = params.problem_shape;
+            auto local_split_kv = params.split_kv;
            if (params.mainloop.ptr_seq != nullptr) {
              get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-	      if (params.ptr_split_kv != nullptr) {
-	        local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
-	      }
+              if (params.ptr_split_kv != nullptr) {
+                local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
+              }
            }
-	    if (local_split_kv <= get<3>(blk_coord))
+            if (local_split_kv <= get<3>(blk_coord))
              continue;
            load_tma</* paged= */ true>(
              blk_coord,
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
              shared_storage.tensors,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              pipeline_load_qk, pipeline_load_qk_producer_state,
-	      local_split_kv
+              local_split_kv
            );
            cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
          }
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
          CUTLASS_PRAGMA_NO_UNROLL
          for (; tile_scheduler.is_valid(); ++tile_scheduler) {
            auto blk_coord = tile_scheduler.get_block_coord();
-	    auto problem_shape = params.problem_shape;
-	    auto local_split_kv = params.split_kv;
+            auto problem_shape = params.problem_shape;
+            auto local_split_kv = params.split_kv;
            if (params.mainloop.ptr_seq != nullptr) {
              get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-	      if (params.ptr_split_kv != nullptr) {
+              if (params.ptr_split_kv != nullptr) {
                local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
-	      }
+              }
            }
-	    if (local_split_kv <= get<3>(blk_coord))
+            if (local_split_kv <= get<3>(blk_coord))
              continue;
            load_tma<false>(
              blk_coord,
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
              shared_storage.tensors,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              pipeline_load_qk, pipeline_load_qk_producer_state,
-	      local_split_kv
+              local_split_kv
            );
            cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
          }
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          auto problem_shape = params.problem_shape;
-	  auto local_split_kv = params.split_kv;
+          auto local_split_kv = params.split_kv;
          if (params.mainloop.ptr_seq != nullptr) {
            get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
            if (params.ptr_split_kv != nullptr) {
                local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
            }
          }
-	  if (local_split_kv <= get<3>(blk_coord))
+          if (local_split_kv <= get<3>(blk_coord))
            continue;
          mma(blk_coord,
            problem_shape,
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
            pipeline_mma_s, pipeline_mma_s_producer_state,
            pipeline_p_mma, pipeline_p_mma_consumer_state,
            pipeline_mma_o, pipeline_mma_o_producer_state,
-	    local_split_kv
+            local_split_kv
          );
        }
      }
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto problem_shape = params.problem_shape;
-	auto split_kv = params.split_kv;
-	auto local_split_kv = split_kv;
+        auto split_kv = params.split_kv;
+        auto local_split_kv = split_kv;
        if (params.mainloop.ptr_seq != nullptr) {
          get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
-	  if (params.ptr_split_kv != nullptr) {
+          if (params.ptr_split_kv != nullptr) {
            local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
          }
        }
-	if (local_split_kv <= get<3>(blk_coord))
+        if (local_split_kv <= get<3>(blk_coord))
          continue;
        compute(
          blk_coord,
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
          pipeline_mma_s, pipeline_mma_s_consumer_state,
          pipeline_p_mma, pipeline_p_mma_producer_state,
          pipeline_mma_o, pipeline_mma_o_consumer_state,
-	  local_split_kv
+          local_split_kv
        );
      }

@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
      cutlass::arch::NamedBarrier(
          (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
          kNamedBarrierEpilogue
-      ).arrive();
+      ).arrive_and_wait();

      return;
    }
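Aside from the tab-to-space cleanup, the recurring pattern in the hunks above is the per-batch split-KV gate: each persistent loop reads an optional per-sequence `ptr_split_kv` override and skips tiles whose split index is not below the local split count. A minimal Python sketch of that gate, assuming `blk_coord` unpacks as (tile, seq, batch, split) to match `get<2>`/`get<3>` in the kernel:

```python
def should_process_tile(blk_coord, default_split_kv, ptr_seq=None, ptr_split_kv=None):
    """blk_coord is assumed to be (m_tile, seq_len_tile, batch_idx, split_idx)."""
    _, _, batch_idx, split_idx = blk_coord
    local_split_kv = default_split_kv
    if ptr_seq is not None and ptr_split_kv is not None:
        # Per-sequence override, as in params.ptr_split_kv[get<2>(blk_coord)]
        local_split_kv = ptr_split_kv[batch_idx]
    # The kernel issues `continue` when local_split_kv <= split_idx.
    return local_split_kv > split_idx


# A sequence that only needs 2 KV splits skips split indices 2 and above.
assert should_process_tile((0, 0, 1, 1), 4, ptr_seq=[128, 32], ptr_split_kv=[4, 2]) is True
assert should_process_tile((0, 0, 1, 2), 4, ptr_seq=[128, 32], ptr_split_kv=[4, 2]) is False
```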
@@ -56,3 +56,11 @@ void cp_gather_cache(
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    int64_t quant_block_size,     // quantization block size
    const std::string& scale_fmt);

@@ -16,6 +16,7 @@

#include <algorithm>
#include <cassert>
#include <cfloat>  // FLT_MIN
#include <map>
#include <vector>
@@ -396,6 +397,180 @@ __global__ void concat_and_cache_mla_kernel(
  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}

template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_ds_mla_kernel(
    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
                                     // + pe_dim)]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride,                    //
    const int entry_stride,                    //
    const int kv_c_stride,                     //
    const int k_pe_stride,                     //
    const int kv_lora_rank,                    //
    const int pe_dim,                          //
    const int block_size,                      //
    const float* scale                         //
) {
  const int64_t token_idx = blockIdx.x;
  const int64_t slot_idx = slot_mapping[token_idx];
  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0) {
    return;
  }
  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;
  const int64_t dst_idx_start =
      block_idx * block_stride + block_offset * entry_stride;

  // Create 4 tile scales in shared memory
  __shared__ float smem[20];
  float* shard_abs_max = smem;
  float* tile_scales = smem + 16;

  // For the NoPE part, each tile of 128 elements is handled by 4 warps
  // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
  // The first thread of the first warp in each tile writes the scale
  // value for the tile. The RoPE part (last 64 elements) is handled
  // by another 2 warps (64 threads).
  // So in total, we use 18 warps (576 threads) per block.

  // Cast kv_cache to 16_bit for RoPE values
  scalar_t* kv_cache_16bit =
      reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);

  // The last 64 threads handle the RoPE part
  if (threadIdx.x >= kv_lora_rank) {
    const int8_t pe_idx = threadIdx.x - kv_lora_rank;
    const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
    // RoPE values start after the packed 8-bit NoPE values and the
    // 32-bit scales
    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
    kv_cache_16bit[dst_idx] = k_pe[src_idx];
    return;
  }

  // Determine the scale for each chunk of NoPE
  const int16_t tile_idx = threadIdx.x >> 7;
  const int16_t warp_idx = (threadIdx.x & 127) >> 5;
  const int16_t lane_idx = threadIdx.x & 31;

  // Load the NoPE element for this thread into registers
  const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
  const scalar_t src_val = kv_c[src_idx];

  // Warp-level reduction to find the max absolute value in the warp
  float max_abs = fabsf(src_val);
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
#ifdef USE_ROCM
    max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
#else
    max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
#endif
  }

  // The first lane of each warp in each tile writes the max_abs of this part
  // of the tile to shared memory
  if (lane_idx == 0) {
    shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
  }
  __syncthreads();

  // The first lane of the first warp in each tile computes the scale for the
  // tile and writes it to shared memory and to kv_cache
  if (warp_idx == 0 && lane_idx == 0) {
    float4 shard_abs_max_vec =
        reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
    float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
                             fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
                       448.f;

    // Avoid division by zero in `scaled_convert`
    tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
    float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
    const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
    kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
  }

  __syncthreads();

  // Now all threads in the block scale and write their element
  const float scale_val = tile_scales[tile_idx];
  const int64_t dst_idx = dst_idx_start + threadIdx.x;
  kv_cache[dst_idx] =
      fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
          src_val, scale_val);
}
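The kernel's indexing spells out the per-token cache entry layout that the host-side checks later enforce (656 bytes per entry): 512 NoPE values stored as one FP8 byte each, then four FP32 tile scales (one per 128-element tile), then 64 RoPE values kept in the original 16-bit dtype. A short sketch of that arithmetic under the checked kv_lora_rank=512 / pe_dim=64 configuration:

```python
def fp8_ds_mla_entry_layout(kv_lora_rank: int = 512, pe_dim: int = 64):
    """Byte layout of one token's entry in the fp8_ds_mla KV cache."""
    nope_bytes = kv_lora_rank * 1            # 512 FP8 (e4m3) NoPE bytes
    scale_bytes = (kv_lora_rank // 128) * 4  # 4 tiles -> 4 float32 scales
    rope_bytes = pe_dim * 2                  # 64 bf16/fp16 RoPE values
    layout = {
        "nope_fp8": (0, nope_bytes),
        # The kernel writes scales at float index kv_lora_rank/4 + tile_idx,
        # i.e. right after the packed NoPE bytes.
        "tile_scales_fp32": (nope_bytes, nope_bytes + scale_bytes),
        # RoPE starts at 16-bit index kv_lora_rank/2 + 8, i.e. byte 512 + 16.
        "rope_16bit": (nope_bytes + scale_bytes,
                       nope_bytes + scale_bytes + rope_bytes),
    }
    return layout, nope_bytes + scale_bytes + rope_bytes


layout, total = fp8_ds_mla_entry_layout()
assert total == 656  # matches the TORCH_CHECK in concat_and_cache_mla
```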
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void indexer_k_quant_and_cache_kernel(
    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int head_dim,                        // dimension of each head
    const int quant_block_size,                // quantization block size
    const int cache_block_size,                // cache block size
    const int cache_stride,  // stride for each token in kv_cache
    const bool use_ue8m0     // use ue8m0 scale format
) {
  constexpr int VEC_SIZE = 4;
  const int64_t token_idx = blockIdx.x;
  const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
                                threadIdx.y * blockDim.x + threadIdx.x) *
                               VEC_SIZE;
  const int64_t slot_idx = slot_mapping[token_idx];
  const int64_t block_idx = slot_idx / cache_block_size;
  const int64_t block_offset = slot_idx % cache_block_size;

  // NOTE: slot_idx can be -1 if the token is padded
  if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
    return;
  }

  float2 k_val = (reinterpret_cast<const float2*>(
      k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
  scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
  float amax = 0.0f;
  for (int i = 0; i < VEC_SIZE; i++) {
    amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
  }
#ifndef USE_ROCM
  __syncwarp();
#endif

  // Reduced amax
  for (int mask = 16; mask > 0; mask /= 2) {
#ifdef USE_ROCM
    amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
#else
    amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
#endif
  }
#ifndef USE_ROCM
  __syncwarp();
#endif
  float scale = fmaxf(amax, 1e-4) / 448.0f;
  if (use_ue8m0) {
    scale = exp2f(ceilf(log2f(scale)));
  }

  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
                             block_offset * head_dim + head_dim_idx;
  for (int i = 0; i < VEC_SIZE; i++) {
    kv_cache[dst_offset + i] =
        fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
  }
  if (threadIdx.x == 0) {
    const int64_t dst_scale_idx =
        block_idx * cache_block_size * cache_stride +
        cache_block_size * head_dim +
        (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
    reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
  }
}

}  // namespace vllm

// KV_T is the data type of key and value tensors.
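The indexer kernel above computes one scale per reduced group as `amax / 448` (448 being the largest finite FP8 E4M3 value) with a small floor, and rounds it up to a power of two when the `ue8m0` scale format is requested. A tiny Python illustration of that scale computation:

```python
import math

def indexer_k_scale(block_amax: float, use_ue8m0: bool) -> float:
    """Per-quant-block scale as in indexer_k_quant_and_cache_kernel:
    amax/448 with a small floor, optionally rounded up to a power of two
    (the 'ue8m0' format effectively stores only the exponent)."""
    scale = max(block_amax, 1e-4) / 448.0
    if use_ue8m0:
        scale = 2.0 ** math.ceil(math.log2(scale))
    return scale


# With ue8m0 the scale is a power of two and never smaller than amax/448.
s = indexer_k_scale(3.7, use_ue8m0=True)
assert s >= 3.7 / 448.0 and math.log2(s).is_integer()
```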
@@ -438,7 +613,7 @@ void reshape_and_cache(
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE)
                             CALL_RESHAPE_AND_CACHE);
}

// KV_T is the data type of key and value tensors.
@@ -509,6 +684,18 @@ void reshape_and_cache_flash(
          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,   \
          reinterpret_cast<const float*>(scale.data_ptr()));

// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE)           \
  vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE>         \
      <<<grid, block, 0, stream>>>(                                     \
          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,   \
          reinterpret_cast<const float*>(scale.data_ptr()));

void concat_and_cache_mla(
    torch::Tensor& kv_c,          // [num_tokens, kv_lora_rank]
    torch::Tensor& k_pe,          // [num_tokens, pe_dim]
@@ -531,20 +718,44 @@ void concat_and_cache_mla(
  int pe_dim = k_pe.size(1);
  int block_size = kv_cache.size(1);

  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  if (kv_cache_dtype == "fp8_ds_mla") {
    TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
    TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
    TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
                "kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
    TORCH_CHECK(kv_c.itemsize() == 2,
                "kv_c.itemsize() must be 2 for fp8_ds_mla");
    TORCH_CHECK(k_pe.itemsize() == 2,
                "k_pe.itemsize() must be 2 for fp8_ds_mla");
  } else {
    TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
  }

  int kv_c_stride = kv_c.stride(0);
  int k_pe_stride = k_pe.stride(0);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(kv_lora_rank, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                             CALL_CONCAT_AND_CACHE_MLA);
  if (kv_cache_dtype == "fp8_ds_mla") {
    dim3 grid(num_tokens);
    // For the NoPE part, each tile of 128 elements is handled by 4 warps
    // (128 threads). There are 4 total tiles, so 16 warps (512 threads).
    // The first thread of the first warp in each tile writes the scale
    // value for the tile. The RoPE part (last 64 elements) is handled
    // by another 2 warps (64 threads).
    // So in total, we use 18 warps (576 threads) per block.
    dim3 block(576);
    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                               CALL_CONCAT_AND_CACHE_DS_MLA);
  } else {
    dim3 grid(num_tokens);
    dim3 block(std::min(kv_lora_rank, 512));
    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                               CALL_CONCAT_AND_CACHE_MLA);
  }
}

namespace vllm {
@@ -922,3 +1133,42 @@ void cp_gather_cache(
    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
}

// Macro to dispatch the kernel based on the data type.
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                     \
          reinterpret_cast<KV_T*>(k.data_ptr()),                        \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
          slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
          cache_block_size, cache_stride, use_ue8m0);

void indexer_k_quant_and_cache(
    torch::Tensor& k,             // [num_tokens, head_dim]
    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
    torch::Tensor& slot_mapping,  // [num_tokens]
    int64_t quant_block_size,     // quantization block size
    const std::string& scale_fmt) {
  int num_tokens = k.size(0);
  int head_dim = k.size(1);
  int cache_block_size = kv_cache.size(1);
  int cache_stride = kv_cache.size(2);
  bool use_ue8m0 = scale_fmt == "ue8m0";

  TORCH_CHECK(k.device() == kv_cache.device(),
              "k and kv_cache must be on the same device");
  TORCH_CHECK(k.device() == slot_mapping.device(),
              "k and slot_mapping must be on the same device");
  TORCH_CHECK(head_dim % quant_block_size == 0,
              "head_dim must be divisible by quant_block_size");

  constexpr int vec_size = 4;
  dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
                            (quant_block_size * vec_size));
  dim3 block(32, vec_size);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
                             CALL_INDEXER_K_QUANT_AND_CACHE);
}
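For orientation, the wrapper above launches one block per token with a `(32, 4)` thread shape, so each warp covers 32 * vec_size = 128 elements and the grid's y dimension walks `head_dim` in chunks of `quant_block_size * vec_size`. The offsets used in the kernel imply a per-cache-block layout of FP8 data for all tokens first, followed by the FP32 scales. A hedged sketch of that bookkeeping (the `cache_stride` value in the example is an assumption, not taken from the diff):

```python
def indexer_cache_offsets(slot_idx: int, head_dim_idx: int, *,
                          head_dim: int, quant_block_size: int,
                          cache_block_size: int, cache_stride: int):
    """Byte offsets mirroring indexer_k_quant_and_cache_kernel.
    cache_stride is the per-token width (in bytes) of kv_cache's last dim."""
    block_idx = slot_idx // cache_block_size
    block_offset = slot_idx % cache_block_size
    # FP8 data region: all tokens of the cache block come first.
    data_offset = (block_idx * cache_block_size * cache_stride
                   + block_offset * head_dim + head_dim_idx)
    # FP32 scales live after the block's data region, one per quant block.
    scale_offset = (block_idx * cache_block_size * cache_stride
                    + cache_block_size * head_dim
                    + (block_offset * head_dim + head_dim_idx) * 4 // quant_block_size)
    return data_offset, scale_offset


# Token in slot 70 of 64-token cache blocks, head_dim 128, 128-element quant blocks.
# Assumed cache_stride: 128 data bytes + 4 scale bytes per token = 132.
data, scale = indexer_cache_offsets(70, 0, head_dim=128, quant_block_size=128,
                                    cache_block_size=64, cache_stride=132)
assert data == 1 * 64 * 132 + 6 * 128
assert scale == 1 * 64 * 132 + 64 * 128 + (6 * 128) * 4 // 128
```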
@@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
          TORCH_CHECK(false,                                                   \
                      "Unsupported input type of kv cache: ", SRC_DTYPE);      \
        }                                                                      \
      } else if (KV_DTYPE == "fp8_ds_mla") {                                   \
        if (SRC_DTYPE == at::ScalarType::Float) {                              \
          FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \
        } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
          FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \
        } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
          FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \
        } else {                                                               \
          TORCH_CHECK(false,                                                   \
                      "Unsupported input type of kv cache: ", SRC_DTYPE);      \
        }                                                                      \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \
      }                                                                        \
@@ -713,6 +713,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
  cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);

  cache_ops.def(
      "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
      "slot_mapping, "
      "int quant_block_size, str kv_cache_dtype) -> ()");
  cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
                 &indexer_k_quant_and_cache);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
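Once registered, the op is reachable from Python through the torch.ops namespace like the other cache ops. A hedged usage sketch; the `_C_cache_ops` library name, the `"ue8m0"` value for the final string argument, and the tensor shapes are inferred from the binding and kernel above rather than from documented API, so treat them as assumptions:

```python
import torch
import vllm._custom_ops  # noqa: F401 — assumed entry point that loads the compiled cache ops

num_tokens, head_dim, quant_block_size = 16, 128, 128
num_blocks, block_size = 8, 64
cache_stride = head_dim + 4 * head_dim // quant_block_size  # fp8 data + fp32 scales (assumed)

k = torch.randn(num_tokens, head_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, cache_stride,
                       dtype=torch.uint8, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")

# Schema from the binding above:
# indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor slot_mapping,
#                           int quant_block_size, str kv_cache_dtype) -> ()
# The string maps to the kernel's scale format; "ue8m0" enables power-of-two scales.
torch.ops._C_cache_ops.indexer_k_quant_and_cache(
    k, kv_cache, slot_mapping, quant_block_size, "ue8m0")
```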
@@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12
#
# Example:
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04

# Important: We build with an old version of Ubuntu to maintain broad
# compatibility with other Linux OSes. The main reason for this is that the
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
# TODO: Restore to base image after FlashInfer AOT wheel fixed
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
@@ -75,34 +80,19 @@ ARG TARGETPLATFORM
ARG INSTALL_KV_CONNECTORS=false
ENV DEBIAN_FRONTEND=noninteractive

ARG DEADSNAKES_MIRROR_URL
ARG DEADSNAKES_GPGKEY_URL
ARG GET_PIP_URL

# Install Python and other dependencies
# Install system dependencies and uv, then create Python virtual environment
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
    && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
    && apt-get update -y \
    && apt-get install -y ccache software-properties-common git curl sudo \
    && if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
        if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
            mkdir -p -m 0755 /etc/apt/keyrings ; \
            curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
            sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
            echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
        fi ; \
    else \
        for i in 1 2 3; do \
            add-apt-repository -y ppa:deadsnakes/ppa && break || \
            { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
        done ; \
    fi \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
    && apt-get install -y ccache software-properties-common git curl sudo python3-pip \
    && curl -LsSf https://astral.sh/uv/install.sh | sh \
    && $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
    && rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
    && ln -s /opt/venv/bin/python3 /usr/bin/python3 \
    && ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
    && ln -s /opt/venv/bin/pip /usr/bin/pip \
    && python3 --version && python3 -m pip --version

ARG PIP_INDEX_URL UV_INDEX_URL
@@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER

# Install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
    python3 -m pip install uv
# Activate virtual environment and add uv to PATH
ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
ENV VIRTUAL_ENV="/opt/venv"

# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
@@ -142,7 +132,7 @@ WORKDIR /workspace
COPY requirements/common.txt requirements/common.txt
COPY requirements/cuda.txt requirements/cuda.txt
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/cuda.txt \
    uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
    --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')

# cuda arch list used by torch
@@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE=copy

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/build.txt \
    uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
    --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')

COPY . .
@@ -269,7 +259,7 @@ COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/dev.txt \
    uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
    --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### DEV IMAGE ####################

@@ -404,6 +394,9 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
                FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
            fi
            echo "🏗️  Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
            export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
            # HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
            uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. -f1)
            # Build AOT kernels
            TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
                python3 -m flashinfer.aot

@@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
#
#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
@@ -6,6 +6,13 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
    We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
    and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.

!!! tip
    When serving multi-modal models, consider setting `--allowed-media-domains` to restrict the domains that vLLM can access, so that it cannot reach arbitrary endpoints that may be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this argument. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`

    Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.

    This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.

## Offline Inference

To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
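The schema example that follows this sentence is truncated by the hunk; purely as a reminder of the general shape, here is a minimal offline sketch assuming the usual prompt-dict form with a `multi_modal_data` field (the model name and the `<image>` placeholder are illustrative and model-specific; defer to the full page for the authoritative schema):

```python
from vllm import LLM
from PIL import Image

# Illustrative only: model choice and the image placeholder vary per model.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")

image = Image.open("example.jpg")
outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is in this picture? ASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```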
@@ -60,6 +60,15 @@ Key points from the PyTorch security guide:
- Implement proper authentication and authorization for management interfaces
- Follow the principle of least privilege for all system components

### 4. **Restrict Domain Access for Media URLs:**

Restrict the domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`).

Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
redirects from being followed to bypass domain restrictions.

## Security and Firewalls: Protecting Exposed vLLM Systems

While vLLM is designed to allow unsafe network services to be isolated to
@@ -54,6 +54,7 @@ def parse_args():
        "--method",
        type=str,
        default="eagle",
        choices=["ngram", "eagle", "eagle3", "mtp"],
    )
    parser.add_argument("--num-spec-tokens", type=int, default=2)
    parser.add_argument("--prompt-lookup-max", type=int, default=5)
@@ -118,9 +119,9 @@ def main(args):
            "prompt_lookup_max": args.prompt_lookup_max,
            "prompt_lookup_min": args.prompt_lookup_min,
        }
    elif args.method.endswith("mtp"):
    elif args.method == "mtp":
        speculative_config = {
            "method": args.method,
            "method": "mtp",
            "num_speculative_tokens": args.num_spec_tokens,
        }
    else:
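The fix above makes the example pass the literal method name `"mtp"` instead of echoing `args.method`. For context, the dictionary built here is ultimately handed to the engine; a hedged sketch of that hand-off, assuming the current `LLM(speculative_config=...)` constructor argument and using a placeholder model name:

```python
from vllm import LLM, SamplingParams

speculative_config = {
    "method": "mtp",               # matches --method mtp in the example script
    "num_speculative_tokens": 2,   # matches --num-spec-tokens
}

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder; use an MTP-capable checkpoint
    speculative_config=speculative_config,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```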
							
								
								
									
setup.py (4 changes)

@@ -322,6 +322,8 @@ class precompiled_wheel_utils:
                    "vllm/_C.abi3.so",
                    "vllm/_moe_C.abi3.so",
                    "vllm/_flashmla_C.abi3.so",
                    "vllm/_flashmla_extension_C.abi3.so",
                    "vllm/_sparse_flashmla_C.abi3.so",
                    "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
                    "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
                    "vllm/cumem_allocator.abi3.so",
@@ -589,6 +591,8 @@ if _is_cuda():
        # not targeting a hopper system
        ext_modules.append(
            CMakeExtension(name="vllm._flashmla_C", optional=True))
        ext_modules.append(
            CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
    ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

if _build_custom_ops():
 | 
			
		||||
                num_kv_heads=self.num_kv_heads,
 | 
			
		||||
                head_size=self.head_size,
 | 
			
		||||
                dtype=self.kv_cache_dtype,
 | 
			
		||||
                use_mla=False,
 | 
			
		||||
            ),
 | 
			
		||||
            layer_names=[self.attn.layer_name],
 | 
			
		||||
            vllm_config=self.vllm_config,
 | 
			
		||||
 | 
			
		||||
@ -45,6 +45,7 @@ class MockModelConfig:
 | 
			
		||||
    logits_processor_pattern: Optional[str] = None
 | 
			
		||||
    diff_sampling_param: Optional[dict] = None
 | 
			
		||||
    allowed_local_media_path: str = ""
 | 
			
		||||
    allowed_media_domains: Optional[list[str]] = None
 | 
			
		||||
    encoder_config = None
 | 
			
		||||
    generation_config: str = "auto"
 | 
			
		||||
    skip_tokenizer_init: bool = False
 | 
			
		||||
 | 
			
		||||
@ -240,6 +240,7 @@ class MockModelConfig:
 | 
			
		||||
    logits_processor_pattern = None
 | 
			
		||||
    diff_sampling_param: Optional[dict] = None
 | 
			
		||||
    allowed_local_media_path: str = ""
 | 
			
		||||
    allowed_media_domains: Optional[list[str]] = None
 | 
			
		||||
    encoder_config = None
 | 
			
		||||
    generation_config: str = "auto"
 | 
			
		||||
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
 | 
			
		||||
                                         parse_chat_messages,
 | 
			
		||||
                                         parse_chat_messages_futures,
 | 
			
		||||
                                         resolve_chat_template_content_format,
 | 
			
		||||
                                         resolve_chat_template_kwargs,
 | 
			
		||||
                                         resolve_hf_chat_template)
 | 
			
		||||
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
 | 
			
		||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
 | 
			
		||||
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
 | 
			
		||||
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
 | 
			
		||||
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 | 
			
		||||
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
 | 
			
		||||
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
 | 
			
		||||
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
 | 
			
		||||
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
 | 
			
		||||
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
 | 
			
		||||
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
 | 
			
		||||
    assert isinstance(chat_template, str)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "model, expected_kwargs",
 | 
			
		||||
    [
 | 
			
		||||
        (
 | 
			
		||||
            QWEN2VL_MODEL_ID,
 | 
			
		||||
            {
 | 
			
		||||
                "add_vision_id", "add_generation_prompt",
 | 
			
		||||
                "continue_final_message", "tools"
 | 
			
		||||
            },
 | 
			
		||||
        ),
 | 
			
		||||
        (
 | 
			
		||||
            QWEN3_MODEL_ID,
 | 
			
		||||
            {
 | 
			
		||||
                "enable_thinking", "add_generation_prompt",
 | 
			
		||||
                "continue_final_message", "tools"
 | 
			
		||||
            },
 | 
			
		||||
        ),
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
 | 
			
		||||
                                         expected_kwargs):
 | 
			
		||||
    """checks that chat_template is a dict type for HF models."""
 | 
			
		||||
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
 | 
			
		||||
    model_info.check_available_online(on_fail="skip")
 | 
			
		||||
 | 
			
		||||
    tools = ([{
 | 
			
		||||
        "type": "function",
 | 
			
		||||
        "function": {
 | 
			
		||||
            "name": "dummy_function_name",
 | 
			
		||||
            "description": "This is a dummy function",
 | 
			
		||||
            "parameters": sample_json_schema,
 | 
			
		||||
        },
 | 
			
		||||
    }])
 | 
			
		||||
 | 
			
		||||
    chat_template_kwargs = {
 | 
			
		||||
        # both unused
 | 
			
		||||
        "unsed_kwargs_1": 123,
 | 
			
		||||
        "unsed_kwargs_2": "abc",
 | 
			
		||||
        # should not appear
 | 
			
		||||
        "chat_template": "{% Hello world! %}",
 | 
			
		||||
        # used by tokenizer
 | 
			
		||||
        "continue_final_message": True,
 | 
			
		||||
        "tools": tools,
 | 
			
		||||
        # both used by Qwen2-VL and Qwen3
 | 
			
		||||
        "add_generation_prompt": True,
 | 
			
		||||
        # only used by Qwen2-VL
 | 
			
		||||
        "add_vision_id": True,
 | 
			
		||||
        # only used by Qwen3
 | 
			
		||||
        "enable_thinking": True,
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    model_config = ModelConfig(
 | 
			
		||||
        model,
 | 
			
		||||
        tokenizer=model_info.tokenizer or model,
 | 
			
		||||
        tokenizer_mode=model_info.tokenizer_mode,
 | 
			
		||||
        revision=model_info.revision,
 | 
			
		||||
        trust_remote_code=model_info.trust_remote_code,
 | 
			
		||||
        hf_overrides=model_info.hf_overrides,
 | 
			
		||||
        skip_tokenizer_init=model_info.skip_tokenizer_init,
 | 
			
		||||
        enforce_eager=model_info.enforce_eager,
 | 
			
		||||
        dtype=model_info.dtype)
 | 
			
		||||
 | 
			
		||||
    # Build the tokenizer
 | 
			
		||||
    tokenizer = get_tokenizer(
 | 
			
		||||
        model,
 | 
			
		||||
        trust_remote_code=model_config.trust_remote_code,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Test detecting the tokenizer's chat_template
 | 
			
		||||
    chat_template = resolve_hf_chat_template(
 | 
			
		||||
        tokenizer,
 | 
			
		||||
        chat_template=None,
 | 
			
		||||
        tools=tools,
 | 
			
		||||
        model_config=model_config,
 | 
			
		||||
    )
 | 
			
		||||
    resolved_chat_template_kwargs = resolve_chat_template_kwargs(
 | 
			
		||||
        tokenizer,
 | 
			
		||||
        chat_template=chat_template,
 | 
			
		||||
        chat_template_kwargs=chat_template_kwargs,
 | 
			
		||||
    )
 | 
			
		||||
    assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# NOTE: Qwen2-Audio default chat template is specially defined inside
 | 
			
		||||
# processor class instead of using `tokenizer_config.json`
 | 
			
		||||
# yapf: disable
 | 
			
		||||
 | 
			
		||||
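
For reference, the contract pinned down by the test above is that resolve_chat_template_kwargs keeps only the keyword arguments that the selected chat template (or the tokenizer's apply_chat_template call itself) can consume, and drops unused and reserved keys such as chat_template. The sketch below only illustrates that filtering idea; the helper name, the reserved-key set and the tokenizer-level kwarg set are assumptions of this note, not vLLM's implementation.

# Illustrative sketch only (assumed names, not vLLM's code): keep kwargs that
# the Jinja template references or that apply_chat_template itself accepts,
# and drop reserved keys.
import jinja2
from jinja2 import meta

ASSUMED_TOKENIZER_LEVEL_KWARGS = {
    "tools", "add_generation_prompt", "continue_final_message"
}
RESERVED_KEYS = {"chat_template"}

def filter_chat_template_kwargs_sketch(chat_template: str,
                                       chat_template_kwargs: dict) -> dict:
    env = jinja2.Environment()
    template_vars = meta.find_undeclared_variables(env.parse(chat_template))
    accepted = (set(template_vars) | ASSUMED_TOKENIZER_LEVEL_KWARGS)
    accepted -= RESERVED_KEYS
    return {k: v for k, v in chat_template_kwargs.items() if k in accepted}

Under such a scheme the unused and reserved entries in chat_template_kwargs above would be discarded, which is exactly what the assertion on resolved_chat_template_kwargs checks.
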
@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
 | 
			
		||||
        torch.testing.assert_close(kv_cache, ref_kv_cache)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
 | 
			
		||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
 | 
			
		||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
 | 
			
		||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
 | 
			
		||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
 | 
			
		||||
@pytest.mark.parametrize("dtype", DTYPES)
 | 
			
		||||
@pytest.mark.parametrize("seed", SEEDS)
 | 
			
		||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
 | 
			
		||||
@torch.inference_mode()
 | 
			
		||||
def test_concat_and_cache_ds_mla(
 | 
			
		||||
    kv_lora_rank: int,
 | 
			
		||||
    qk_rope_head_dim: int,
 | 
			
		||||
    num_tokens: int,
 | 
			
		||||
    block_size: int,
 | 
			
		||||
    num_blocks: int,
 | 
			
		||||
    dtype: torch.dtype,
 | 
			
		||||
    seed: int,
 | 
			
		||||
    device: str,
 | 
			
		||||
) -> None:
 | 
			
		||||
    if dtype.itemsize != 2:
 | 
			
		||||
        pytest.skip("ds_mla only supports 16-bit input")
 | 
			
		||||
    kv_cache_dtype = "fp8_ds_mla"
 | 
			
		||||
    current_platform.seed_everything(seed)
 | 
			
		||||
    torch.set_default_device(device)
 | 
			
		||||
 | 
			
		||||
    total_slots = num_blocks * block_size
 | 
			
		||||
    slot_mapping_lst = random.sample(range(total_slots), num_tokens)
 | 
			
		||||
    slot_mapping = torch.tensor(slot_mapping_lst,
 | 
			
		||||
                                dtype=torch.long,
 | 
			
		||||
                                device=device)
 | 
			
		||||
 | 
			
		||||
    kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
 | 
			
		||||
    k_pe = torch.randn(num_tokens,
 | 
			
		||||
                       qk_rope_head_dim,
 | 
			
		||||
                       dtype=dtype,
 | 
			
		||||
                       device=device)
 | 
			
		||||
    entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
 | 
			
		||||
 | 
			
		||||
    scale = torch.tensor(1.0, dtype=torch.float32, device=device)
 | 
			
		||||
    kv_cache = _create_mla_cache(num_blocks,
 | 
			
		||||
                                 block_size,
 | 
			
		||||
                                 entry_size,
 | 
			
		||||
                                 dtype=torch.uint8,
 | 
			
		||||
                                 kv_cache_dtype=kv_cache_dtype,
 | 
			
		||||
                                 device=device)
 | 
			
		||||
 | 
			
		||||
    ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
 | 
			
		||||
    tile_data = torch.zeros(128, dtype=dtype, device=device)
 | 
			
		||||
 | 
			
		||||
    for i in range(num_tokens):
 | 
			
		||||
        slot = slot_mapping[i].item()
 | 
			
		||||
        block_idx = slot // block_size
 | 
			
		||||
        block_offset = slot % block_size
 | 
			
		||||
 | 
			
		||||
        ref_cache_slice = ref_cache[block_idx, block_offset]
 | 
			
		||||
        ref_cache_16bit = ref_cache_slice.view(dtype)
 | 
			
		||||
        ref_cache_32bit = ref_cache_slice.view(torch.float32)
 | 
			
		||||
 | 
			
		||||
        kv_c_data = kv_c[i]
 | 
			
		||||
        for tile_idx in range(4):
 | 
			
		||||
            tile_start = tile_idx * 128
 | 
			
		||||
            tile_end = (tile_idx + 1) * 128
 | 
			
		||||
            tile_data[:] = kv_c_data[tile_start:tile_end]
 | 
			
		||||
 | 
			
		||||
            # tile_scale = tile_data.amax().to(torch.float32) / 448.
 | 
			
		||||
            # NOTE: Using torch's amax() gives different results,
 | 
			
		||||
            # so this must be manually computed.
 | 
			
		||||
            tile_data_float = tile_data.to(torch.float32)
 | 
			
		||||
            manual_max = abs(tile_data_float[0])
 | 
			
		||||
            for j in range(1, 128):
 | 
			
		||||
                manual_max = max(manual_max, abs(tile_data_float[j]))
 | 
			
		||||
            tile_scale = manual_max / 448.
 | 
			
		||||
 | 
			
		||||
            ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
 | 
			
		||||
 | 
			
		||||
            ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
 | 
			
		||||
                            tile_data,
 | 
			
		||||
                            tile_scale.item(),
 | 
			
		||||
                            kv_dtype="fp8")
 | 
			
		||||
 | 
			
		||||
        for j in range(qk_rope_head_dim):
 | 
			
		||||
            ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
 | 
			
		||||
 | 
			
		||||
    opcheck(
 | 
			
		||||
        torch.ops._C_cache_ops.concat_and_cache_mla,
 | 
			
		||||
        (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
 | 
			
		||||
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
 | 
			
		||||
                             kv_cache_dtype, scale)
 | 
			
		||||
 | 
			
		||||
    for i in range(num_tokens):
 | 
			
		||||
        slot = slot_mapping[i].item()
 | 
			
		||||
        block_idx = slot // block_size
 | 
			
		||||
        block_offset = slot % block_size
 | 
			
		||||
        kv_cache_slice = kv_cache[block_idx, block_offset]
 | 
			
		||||
        ref_cache_slice = ref_cache[block_idx, block_offset]
 | 
			
		||||
 | 
			
		||||
        kv_nope = kv_cache_slice[:kv_lora_rank]
 | 
			
		||||
        ref_nope = ref_cache_slice[:kv_lora_rank]
 | 
			
		||||
        kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
 | 
			
		||||
                                                       4:kv_lora_rank // 4 + 4]
 | 
			
		||||
        ref_scales = ref_cache_slice.view(
 | 
			
		||||
            torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
 | 
			
		||||
        kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
 | 
			
		||||
        ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
 | 
			
		||||
 | 
			
		||||
        torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
 | 
			
		||||
        torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
 | 
			
		||||
        torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
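
The index arithmetic in the reference loop above encodes the fp8_ds_mla cache entry layout: kv_lora_rank fp8 bytes of latent (quantized in four 128-wide tiles), then four float32 per-tile scales, then the rope part kept in the original 16-bit dtype. A small reading aid follows; the class name is made up for illustration, the offsets mirror the test's indexing.

# Reading aid for the fp8_ds_mla entry layout exercised above (sketch only).
class DsMlaEntryLayout:
    def __init__(self, kv_lora_rank: int, qk_rope_head_dim: int):
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

    @property
    def entry_size_bytes(self) -> int:
        # fp8 latent + 4 float32 tile scales + 16-bit rope
        return self.kv_lora_rank + 4 * 4 + 2 * self.qk_rope_head_dim

    @property
    def scale_slice_fp32(self) -> slice:
        # matches cache_slice.view(torch.float32)[kv_lora_rank//4 : kv_lora_rank//4 + 4]
        return slice(self.kv_lora_rank // 4, self.kv_lora_rank // 4 + 4)

    @property
    def rope_start_16bit(self) -> int:
        # matches cache_slice.view(dtype)[kv_lora_rank//2 + 8 :]
        return self.kv_lora_rank // 2 + 8

# With DeepSeek-style dimensions this gives the 656-byte entries that also
# appear later in the sparse FlashMLA tests (assumed 512/64 split):
assert DsMlaEntryLayout(512, 64).entry_size_bytes == 656
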
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
 | 
			
		||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
 | 
			
		||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
 | 
			
		||||

tests/kernels/attention/test_deepgemm_attention.py  (new file, 279 lines)
@ -0,0 +1,279 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | 
			
		||||
import random
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from vllm.platforms import current_platform
 | 
			
		||||
from vllm.utils import cdiv, has_deep_gemm
 | 
			
		||||
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
 | 
			
		||||
                                  fp8_paged_mqa_logits, get_num_sms,
 | 
			
		||||
                                  get_paged_mqa_logits_metadata)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    # x: (num_blocks, block_size, 1, head_dim)
 | 
			
		||||
    num_blocks, block_size, num_heads, head_dim = x.shape
 | 
			
		||||
    assert num_heads == 1
 | 
			
		||||
    x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
 | 
			
		||||
    sf = x_amax / 448.0
 | 
			
		||||
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
 | 
			
		||||
    x_fp8 = torch.empty(
 | 
			
		||||
        (num_blocks, block_size * (head_dim + 4)),
 | 
			
		||||
        device=x.device,
 | 
			
		||||
        dtype=torch.uint8,
 | 
			
		||||
    )
 | 
			
		||||
    x_fp8[:, :block_size * head_dim] = x_scaled.view(
 | 
			
		||||
        num_blocks, block_size * head_dim).view(dtype=torch.uint8)
 | 
			
		||||
    x_fp8[:,
 | 
			
		||||
          block_size * head_dim:] = sf.view(num_blocks,
 | 
			
		||||
                                            block_size).view(dtype=torch.uint8)
 | 
			
		||||
    return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
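
Note that kv_cache_cast_to_fp8 above packs each block as block_size * head_dim fp8 bytes followed by block_size float32 scales, and only then reinterprets the buffer as (num_blocks, block_size, 1, head_dim + 4); the trailing four "columns" are a view artifact rather than per-token scales stored in place. A hedged sketch of the inverse transform, useful only for inspecting the layout (not used by the kernels):

import torch

def kv_cache_decode_from_fp8_sketch(x_fp8: torch.Tensor) -> torch.Tensor:
    # Inverse of kv_cache_cast_to_fp8, for inspection purposes only.
    num_blocks, block_size, num_heads, head_dim_plus_4 = x_fp8.shape
    assert num_heads == 1
    head_dim = head_dim_plus_4 - 4
    flat = x_fp8.reshape(num_blocks, -1)
    vals = (flat[:, :block_size * head_dim].contiguous()
            .view(torch.float8_e4m3fn)
            .view(num_blocks, block_size, num_heads, head_dim)
            .to(torch.float32))
    scales = (flat[:, block_size * head_dim:].contiguous()
              .view(torch.float32)
              .view(num_blocks, block_size, num_heads, 1))
    return vals * scales
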
def per_custom_dims_cast_to_fp8(
 | 
			
		||||
        x: torch.Tensor, dims: tuple,
 | 
			
		||||
        use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
 | 
			
		||||
    x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
 | 
			
		||||
    sf = x_amax / 448.0
 | 
			
		||||
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
 | 
			
		||||
    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
 | 
			
		||||
    return x_scaled, sf.squeeze()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
 | 
			
		||||
    assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
 | 
			
		||||
    chunk_size = seq_len // 2
 | 
			
		||||
    cp_size = seq_len_kv // seq_len
 | 
			
		||||
    cp_id = cp_size // 3
 | 
			
		||||
    ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
 | 
			
		||||
    ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
 | 
			
		||||
    for i in range(chunk_size):
 | 
			
		||||
        ke[i] = cp_id * chunk_size + i
 | 
			
		||||
        ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
 | 
			
		||||
    return ks, ke
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _ref_fp8_mqa_logits(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    kv: torch.Tensor,
 | 
			
		||||
    weights: torch.Tensor,
 | 
			
		||||
    cu_seqlen_ks: torch.Tensor,
 | 
			
		||||
    cu_seqlen_ke: torch.Tensor,
 | 
			
		||||
):
 | 
			
		||||
    seq_len_kv = kv.shape[0]
 | 
			
		||||
 | 
			
		||||
    k = kv
 | 
			
		||||
    q = q.float()
 | 
			
		||||
    k = k.float()
 | 
			
		||||
 | 
			
		||||
    mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
 | 
			
		||||
               >= cu_seqlen_ks[:, None])
 | 
			
		||||
    mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
 | 
			
		||||
               < cu_seqlen_ke[:, None])
 | 
			
		||||
    mask = mask_lo & mask_hi
 | 
			
		||||
 | 
			
		||||
    score = torch.einsum("mhd,and->hmn", q, k)
 | 
			
		||||
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
 | 
			
		||||
    logits = logits.masked_fill(~mask, float("-inf"))
 | 
			
		||||
 | 
			
		||||
    return logits
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
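
In formula form, the reference above computes the following (notation mine: w are the per-head weights, ks/ke the per-query key range):

\mathrm{logits}_{i,j} \;=\;
  \begin{cases}
    \sum_{h} w_{i,h}\,\mathrm{ReLU}\!\left(q_{i,h}\cdot k_{j}\right), & \text{if } ks_i \le j < ke_i,\\
    -\infty, & \text{otherwise.}
  \end{cases}
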
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 | 
			
		||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
 | 
			
		||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
 | 
			
		||||
                    reason="SM90 and SM100 only")
 | 
			
		||||
def test_deepgemm_fp8_mqa_logits():
 | 
			
		||||
    torch.manual_seed(0)
 | 
			
		||||
    random.seed(0)
 | 
			
		||||
    num_heads, head_dim = 32, 128
 | 
			
		||||
    for seq_len in (512, ):
 | 
			
		||||
        for seq_len_kv in (1024, ):
 | 
			
		||||
            for disable_cp in (False, True):
 | 
			
		||||
                q = torch.randn(
 | 
			
		||||
                    seq_len,
 | 
			
		||||
                    num_heads,
 | 
			
		||||
                    head_dim,
 | 
			
		||||
                    device="cuda",
 | 
			
		||||
                    dtype=torch.bfloat16,
 | 
			
		||||
                )
 | 
			
		||||
                kv = torch.randn(seq_len_kv,
 | 
			
		||||
                                 head_dim,
 | 
			
		||||
                                 device="cuda",
 | 
			
		||||
                                 dtype=torch.bfloat16)
 | 
			
		||||
                weights = torch.randn(seq_len,
 | 
			
		||||
                                      num_heads,
 | 
			
		||||
                                      device="cuda",
 | 
			
		||||
                                      dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
                if disable_cp:
 | 
			
		||||
                    ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
 | 
			
		||||
                    ke = torch.arange(seq_len, dtype=torch.int,
 | 
			
		||||
                                      device="cuda") + (seq_len_kv - seq_len)
 | 
			
		||||
                else:
 | 
			
		||||
                    ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
 | 
			
		||||
 | 
			
		||||
                q_fp8 = q.to(torch.float8_e4m3fn)
 | 
			
		||||
                kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
 | 
			
		||||
                logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
 | 
			
		||||
 | 
			
		||||
                ref_logits = _ref_fp8_mqa_logits(
 | 
			
		||||
                    q=q,
 | 
			
		||||
                    kv=kv,
 | 
			
		||||
                    weights=weights,
 | 
			
		||||
                    cu_seqlen_ks=ks,
 | 
			
		||||
                    cu_seqlen_ke=ke,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                ref_neginf_mask = ref_logits == float("-inf")
 | 
			
		||||
                neginf_mask = logits == float("-inf")
 | 
			
		||||
                assert torch.equal(neginf_mask, ref_neginf_mask)
 | 
			
		||||
 | 
			
		||||
                ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
 | 
			
		||||
                logits = logits.masked_fill(neginf_mask, 0)
 | 
			
		||||
                diff = calc_diff(logits, ref_logits)
 | 
			
		||||
                assert diff < 1e-3, f"{diff=}"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _ref_fp8_paged_mqa_logits(
 | 
			
		||||
    q: torch.Tensor,
 | 
			
		||||
    kv_cache: torch.Tensor,
 | 
			
		||||
    weights: torch.Tensor,
 | 
			
		||||
    context_lens: torch.Tensor,
 | 
			
		||||
    block_tables: torch.Tensor,
 | 
			
		||||
    max_model_len: int,
 | 
			
		||||
):
 | 
			
		||||
    batch_size, next_n, _, _ = q.size()
 | 
			
		||||
    _, block_size, _, _ = kv_cache.size()
 | 
			
		||||
    logits = torch.full(
 | 
			
		||||
        [batch_size * next_n, max_model_len],
 | 
			
		||||
        float("-inf"),
 | 
			
		||||
        device=q.device,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
    )
 | 
			
		||||
    context_lens_list = context_lens.tolist()
 | 
			
		||||
    for i in range(batch_size):
 | 
			
		||||
        context_len = context_lens_list[i]
 | 
			
		||||
        q_offsets = torch.arange(context_len - next_n,
 | 
			
		||||
                                 context_len,
 | 
			
		||||
                                 device="cuda")
 | 
			
		||||
        weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
 | 
			
		||||
            0, 1).contiguous())
 | 
			
		||||
        for block_rk in range(cdiv(context_len, block_size)):
 | 
			
		||||
            block_idx = block_tables[i][block_rk]
 | 
			
		||||
            qx, kx = q[i], kv_cache[block_idx]
 | 
			
		||||
            k_offsets = torch.arange(
 | 
			
		||||
                block_rk * block_size,
 | 
			
		||||
                (block_rk + 1) * block_size,
 | 
			
		||||
                device="cuda",
 | 
			
		||||
            )
 | 
			
		||||
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
 | 
			
		||||
                                                         <= q_offsets[:, None])
 | 
			
		||||
            s = torch.where(
 | 
			
		||||
                mask[None, :, :],
 | 
			
		||||
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
 | 
			
		||||
                    logits.dtype),
 | 
			
		||||
                float("-inf"),
 | 
			
		||||
            )
 | 
			
		||||
            s = torch.relu(s) * weight_slice[..., None]
 | 
			
		||||
            s = s.sum(dim=0)
 | 
			
		||||
            logits[
 | 
			
		||||
                i * next_n:(i + 1) * next_n,
 | 
			
		||||
                block_rk * block_size:(block_rk + 1) * block_size,
 | 
			
		||||
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
 | 
			
		||||
                            float("-inf"))
 | 
			
		||||
    return logits
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
 | 
			
		||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
 | 
			
		||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
 | 
			
		||||
                    reason="SM90 and SM100 only")
 | 
			
		||||
def test_deepgemm_fp8_paged_mqa_logits():
 | 
			
		||||
    torch.manual_seed(0)
 | 
			
		||||
    random.seed(0)
 | 
			
		||||
 | 
			
		||||
    max_model_len = 4096
 | 
			
		||||
    for batch_size, next_n in [(4, 1), (2, 2)]:
 | 
			
		||||
        for heads, index_dim in [(32, 128)]:
 | 
			
		||||
            for avg_kv in (2048, ):
 | 
			
		||||
                num_blocks, blocksize = max_model_len * 2, 64
 | 
			
		||||
 | 
			
		||||
                q = torch.randn(
 | 
			
		||||
                    (batch_size, next_n, heads, index_dim),
 | 
			
		||||
                    device="cuda",
 | 
			
		||||
                    dtype=torch.bfloat16,
 | 
			
		||||
                )
 | 
			
		||||
                kv_cache = torch.randn(
 | 
			
		||||
                    (num_blocks, blocksize, 1, index_dim),
 | 
			
		||||
                    device="cuda",
 | 
			
		||||
                    dtype=torch.bfloat16,
 | 
			
		||||
                )
 | 
			
		||||
                weights = torch.randn(
 | 
			
		||||
                    (batch_size * next_n, heads),
 | 
			
		||||
                    device="cuda",
 | 
			
		||||
                    dtype=torch.float32,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                context_lens = (torch.randint(int(0.8 * avg_kv),
 | 
			
		||||
                                              int(1.2 * avg_kv),
 | 
			
		||||
                                              (batch_size, )).cuda().to(
 | 
			
		||||
                                                  torch.int32))
 | 
			
		||||
                max_block_len = ((context_lens.max().item() + blocksize - 1) //
 | 
			
		||||
                                 blocksize * blocksize)
 | 
			
		||||
                block_tables = torch.zeros(
 | 
			
		||||
                    (batch_size, max_block_len),
 | 
			
		||||
                    device="cuda",
 | 
			
		||||
                    dtype=torch.int32,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                counter = 0
 | 
			
		||||
                block_idx_pool = list(range(num_blocks))
 | 
			
		||||
                random.shuffle(block_idx_pool)
 | 
			
		||||
                for i in range(batch_size):
 | 
			
		||||
                    ctx_len = int(context_lens[i].item())
 | 
			
		||||
                    for j in range((ctx_len + blocksize - 1) // blocksize):
 | 
			
		||||
                        block_tables[i][j] = block_idx_pool[counter]
 | 
			
		||||
                        counter += 1
 | 
			
		||||
 | 
			
		||||
                q_fp8 = q.to(torch.float8_e4m3fn)
 | 
			
		||||
                kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
 | 
			
		||||
 | 
			
		||||
                schedule_metadata = get_paged_mqa_logits_metadata(
 | 
			
		||||
                    context_lens, blocksize, get_num_sms())
 | 
			
		||||
                logits = fp8_paged_mqa_logits(
 | 
			
		||||
                    q_fp8,
 | 
			
		||||
                    kv_cache_fp8,
 | 
			
		||||
                    weights,
 | 
			
		||||
                    context_lens,
 | 
			
		||||
                    block_tables,
 | 
			
		||||
                    schedule_metadata,
 | 
			
		||||
                    max_model_len,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                ref_logits = _ref_fp8_paged_mqa_logits(
 | 
			
		||||
                    q,
 | 
			
		||||
                    kv_cache,
 | 
			
		||||
                    weights,
 | 
			
		||||
                    context_lens,
 | 
			
		||||
                    block_tables,
 | 
			
		||||
                    max_model_len,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                positions = (torch.arange(max_model_len,
 | 
			
		||||
                                          device="cuda").unsqueeze(0).expand(
 | 
			
		||||
                                              batch_size * next_n, -1))
 | 
			
		||||
                row_indices = (
 | 
			
		||||
                    torch.arange(batch_size * next_n, device="cuda") // next_n)
 | 
			
		||||
                next_n_offset = (
 | 
			
		||||
                    torch.arange(batch_size * next_n, device="cuda") % next_n)
 | 
			
		||||
                mask = positions <= (context_lens[row_indices] - next_n +
 | 
			
		||||
                                     next_n_offset).unsqueeze(1)
 | 
			
		||||
 | 
			
		||||
                logits = logits.masked_fill(~mask, 0)
 | 
			
		||||
                ref_logits = ref_logits.masked_fill(~mask, 0)
 | 
			
		||||
                diff = calc_diff(logits, ref_logits)
 | 
			
		||||
                assert diff < 1e-3, f"{diff=}"
 | 
			
		||||
@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
 | 
			
		||||
        descale_k = None
 | 
			
		||||
 | 
			
		||||
    def flash_mla():
 | 
			
		||||
        return flash_mla_with_kvcache(
 | 
			
		||||
            q,
 | 
			
		||||
            blocked_k,
 | 
			
		||||
            block_table,
 | 
			
		||||
            cache_seqlens,
 | 
			
		||||
            dv,
 | 
			
		||||
            tile_scheduler_metadata,
 | 
			
		||||
            num_splits,
 | 
			
		||||
            causal=causal,
 | 
			
		||||
            descale_q=descale_q,
 | 
			
		||||
            descale_k=descale_k,
 | 
			
		||||
        )
 | 
			
		||||
        return flash_mla_with_kvcache(q,
 | 
			
		||||
                                      blocked_k,
 | 
			
		||||
                                      block_table,
 | 
			
		||||
                                      cache_seqlens,
 | 
			
		||||
                                      dv,
 | 
			
		||||
                                      tile_scheduler_metadata,
 | 
			
		||||
                                      num_splits,
 | 
			
		||||
                                      causal=causal,
 | 
			
		||||
                                      descale_q=descale_q,
 | 
			
		||||
                                      descale_k=descale_k)
 | 
			
		||||
 | 
			
		||||
    def scaled_dot_product_attention(query, key, value, is_causal=False):
 | 
			
		||||
        query = query.float()
 | 
			
		||||

tests/kernels/attention/test_flashmla_sparse.py  (new file, 119 lines)
@ -0,0 +1,119 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _cuda_sm90_available() -> bool:
 | 
			
		||||
    if not torch.cuda.is_available():
 | 
			
		||||
        return False
 | 
			
		||||
    major, _ = torch.cuda.get_device_capability()
 | 
			
		||||
    return major == 9
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_sparse_flashmla_metadata_smoke():
 | 
			
		||||
    import vllm.attention.ops.flashmla as fm
 | 
			
		||||
    ok, reason = fm.is_flashmla_supported()
 | 
			
		||||
    if not ok or not _cuda_sm90_available():
 | 
			
		||||
        pytest.skip(reason or "SM90 not available")
 | 
			
		||||
 | 
			
		||||
    device = torch.device("cuda")
 | 
			
		||||
    batch_size = 1
 | 
			
		||||
    seqlen_q = 1
 | 
			
		||||
    num_heads_q = 128
 | 
			
		||||
    num_heads_k = 1
 | 
			
		||||
    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
 | 
			
		||||
    topk = 128
 | 
			
		||||
 | 
			
		||||
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
 | 
			
		||||
 | 
			
		||||
    tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
 | 
			
		||||
                                              q_seq_per_hk,
 | 
			
		||||
                                              num_heads_k,
 | 
			
		||||
                                              num_heads_q=num_heads_q,
 | 
			
		||||
                                              topk=topk,
 | 
			
		||||
                                              is_fp8_kvcache=True)
 | 
			
		||||
    assert tile_md.dtype == torch.int32
 | 
			
		||||
    assert num_splits.dtype == torch.int32
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_sparse_flashmla_decode_smoke():
 | 
			
		||||
    import vllm.attention.ops.flashmla as fm
 | 
			
		||||
    ok, reason = fm.is_flashmla_supported()
 | 
			
		||||
    if not ok or not _cuda_sm90_available():
 | 
			
		||||
        pytest.skip(reason or "SM90 not available")
 | 
			
		||||
 | 
			
		||||
    device = torch.device("cuda")
 | 
			
		||||
    batch_size = 1
 | 
			
		||||
    seqlen_q = 1
 | 
			
		||||
    num_heads_q = 1
 | 
			
		||||
    head_dim_k = 576
 | 
			
		||||
    head_dim_v = 512
 | 
			
		||||
    num_heads_k = 1
 | 
			
		||||
    page_block_size = 64
 | 
			
		||||
    bytes_per_token = 656
 | 
			
		||||
    topk = 128
 | 
			
		||||
 | 
			
		||||
    # Metadata
 | 
			
		||||
    q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
 | 
			
		||||
    # q_heads_per_hk = num_heads_q // num_heads_k
 | 
			
		||||
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
 | 
			
		||||
    tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
 | 
			
		||||
                                              q_seq_per_hk,
 | 
			
		||||
                                              num_heads_k,
 | 
			
		||||
                                              num_heads_q=num_heads_q,
 | 
			
		||||
                                              topk=topk,
 | 
			
		||||
                                              is_fp8_kvcache=True)
 | 
			
		||||
 | 
			
		||||
    # Inputs
 | 
			
		||||
    q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
 | 
			
		||||
                    dtype=torch.bfloat16,
 | 
			
		||||
                    device=device)
 | 
			
		||||
    k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
 | 
			
		||||
                          dtype=torch.uint8,
 | 
			
		||||
                          device=device)
 | 
			
		||||
    indices = torch.zeros((batch_size, seqlen_q, topk),
 | 
			
		||||
                          dtype=torch.int32,
 | 
			
		||||
                          device=device)
 | 
			
		||||
 | 
			
		||||
    block_table = torch.zeros((batch_size, 128),
 | 
			
		||||
                              dtype=torch.int32,
 | 
			
		||||
                              device=device)
 | 
			
		||||
    out, lse = fm.flash_mla_with_kvcache(q,
 | 
			
		||||
                                         k_cache,
 | 
			
		||||
                                         block_table,
 | 
			
		||||
                                         cache_seqlens,
 | 
			
		||||
                                         head_dim_v,
 | 
			
		||||
                                         tile_md,
 | 
			
		||||
                                         num_splits,
 | 
			
		||||
                                         indices=indices,
 | 
			
		||||
                                         is_fp8_kvcache=True)
 | 
			
		||||
    assert out.shape[0] == batch_size
 | 
			
		||||
    assert out.shape[-1] == head_dim_v
 | 
			
		||||
    assert lse.shape[0] == batch_size
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
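
The bytes_per_token = 656 constant in the decode smoke test above is consistent with the fp8 KV-cache entry layout exercised elsewhere in this change, assuming a 512-wide latent and a 64-wide rope part:

# Where 656 bytes per token comes from (assumed kv_lora_rank=512, rope_dim=64):
kv_lora_rank, rope_dim = 512, 64
bytes_per_token = kv_lora_rank * 1 + 4 * 4 + rope_dim * 2
#                 fp8 latent      fp32 scales  16-bit rope
assert bytes_per_token == 656
assert kv_lora_rank + rope_dim == 576  # head_dim_k above
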
def test_sparse_flashmla_prefill_smoke():
 | 
			
		||||
    import vllm.attention.ops.flashmla as fm
 | 
			
		||||
    ok, reason = fm.is_flashmla_supported()
 | 
			
		||||
    if not ok or not _cuda_sm90_available():
 | 
			
		||||
        pytest.skip(reason or "SM90 not available")
 | 
			
		||||
 | 
			
		||||
    device = torch.device("cuda")
 | 
			
		||||
    s_q = 1
 | 
			
		||||
    s_kv = 1
 | 
			
		||||
    h_q = 64  # kernel expects multiple of 64
 | 
			
		||||
    h_kv = 1
 | 
			
		||||
    d_qk = 576
 | 
			
		||||
    d_v = 512
 | 
			
		||||
    topk = 128
 | 
			
		||||
 | 
			
		||||
    q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
 | 
			
		||||
    kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
 | 
			
		||||
    indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
 | 
			
		||||
 | 
			
		||||
    out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
 | 
			
		||||
                                                       d_v)
 | 
			
		||||
    assert out.shape == (s_q, h_q, d_v)
 | 
			
		||||
    assert max_logits.shape == (s_q, h_q)
 | 
			
		||||
    assert lse.shape == (s_q, h_q)

tests/kernels/attention/test_pack_unpack_triton.py  (new file, 245 lines)
@ -0,0 +1,245 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.testing import assert_close
 | 
			
		||||
 | 
			
		||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_seq_basic_fp8():
 | 
			
		||||
    """Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
 | 
			
		||||
    # Test cases with 3D tensors (N, H, D)
 | 
			
		||||
    test_cases = [
 | 
			
		||||
        (6, 8, 4, 2, [3, 3]),  # (6, 8, 4) -> (2, 3, 8, 4)
 | 
			
		||||
        (10, 4, 8, 3, [2, 4, 4]),  # (10, 4, 8) -> (3, 4, 4, 8)
 | 
			
		||||
        (20, 16, 32, 4, [5, 5, 5, 5]),  # (20, 16, 32) -> (4, 5, 16, 32)
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    for N, H, D, B, lengths_list in test_cases:
 | 
			
		||||
        # Create input tensor with small values for fp8
 | 
			
		||||
        x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
        x = x.to(dtype=dtype)
 | 
			
		||||
        lengths = torch.tensor(lengths_list, device=device)
 | 
			
		||||
 | 
			
		||||
        # Pack the data
 | 
			
		||||
        packed = pack_seq_triton(x, lengths)
 | 
			
		||||
 | 
			
		||||
        # Check output shape and properties
 | 
			
		||||
        expected_shape = (B, max(lengths_list), H, D)
 | 
			
		||||
        assert packed.shape == expected_shape
 | 
			
		||||
        assert packed.dtype == dtype
 | 
			
		||||
        assert packed.device == x.device
 | 
			
		||||
 | 
			
		||||
        # Check that valid data is preserved (within fp8 precision)
 | 
			
		||||
        for b in range(B):
 | 
			
		||||
            start_idx = sum(lengths_list[:b])
 | 
			
		||||
            seq_len = lengths_list[b]
 | 
			
		||||
 | 
			
		||||
            expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
 | 
			
		||||
            actual_data = packed[b, :seq_len].to(torch.float32)
 | 
			
		||||
 | 
			
		||||
            assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_seq_custom_padding_fp8():
 | 
			
		||||
    """Test pack_seq_triton with custom padding values for fp8."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
    N, H, D, B = 20, 8, 16, 2
 | 
			
		||||
    lengths = torch.tensor([10, 10], device=device)
 | 
			
		||||
 | 
			
		||||
    x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
 | 
			
		||||
    # Test with different padding values
 | 
			
		||||
    for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
 | 
			
		||||
        result = pack_seq_triton(x, lengths, pad_value=pad_value)
 | 
			
		||||
 | 
			
		||||
        # Check valid data
 | 
			
		||||
        for b in range(B):
 | 
			
		||||
            start_idx = b * 10
 | 
			
		||||
            expected_data = x[start_idx:start_idx + 10].to(torch.float32)
 | 
			
		||||
            actual_data = result[b, :10].to(torch.float32)
 | 
			
		||||
            assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
 | 
			
		||||
 | 
			
		||||
        # Check padding (fp8 has limited range, so check for large values)
 | 
			
		||||
        padded_data = result[:, 10:].to(torch.float32)
 | 
			
		||||
        if pad_value < 0:
 | 
			
		||||
            assert torch.all(padded_data < -50)  # Large negative values
 | 
			
		||||
        elif pad_value > 0:
 | 
			
		||||
            assert torch.all(padded_data > 50)  # Large positive values
 | 
			
		||||
        else:
 | 
			
		||||
            assert torch.allclose(padded_data,
 | 
			
		||||
                                  torch.zeros_like(padded_data),
 | 
			
		||||
                                  atol=1e-2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_seq_default_negative_inf_padding_fp8():
 | 
			
		||||
    """Test that pack_seq_triton uses -inf padding by default for fp8."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
    # B = 2
 | 
			
		||||
    N, H, D = 20, 8, 16
 | 
			
		||||
    lengths = torch.tensor([10, 10], device=device)
 | 
			
		||||
 | 
			
		||||
    x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    result = pack_seq_triton(x, lengths)
 | 
			
		||||
 | 
			
		||||
    # Check that padding is large negative values (fp8 representation of -inf)
 | 
			
		||||
    padded_data = result[:, 10:].to(torch.float32)
 | 
			
		||||
    assert torch.all(
 | 
			
		||||
        padded_data < -100)  # fp8 -inf is represented as large negative number
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_seq_edge_cases_fp8():
 | 
			
		||||
    """Test pack_seq_triton with edge cases for fp8."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
 | 
			
		||||
    # Test with single batch element
 | 
			
		||||
    x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    lengths = torch.tensor([10], device=device)
 | 
			
		||||
    result = pack_seq_triton(x, lengths)
 | 
			
		||||
    assert result.shape == (1, 10, 8, 16)
 | 
			
		||||
 | 
			
		||||
    # Test with very short sequences
 | 
			
		||||
    x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    lengths = torch.tensor([1, 1, 1], device=device)
 | 
			
		||||
    result = pack_seq_triton(x, lengths)
 | 
			
		||||
    assert result.shape == (3, 1, 4, 8)
 | 
			
		||||
 | 
			
		||||
    # Test with different sequence lengths
 | 
			
		||||
    x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    lengths = torch.tensor([5, 7, 3], device=device)
 | 
			
		||||
    result = pack_seq_triton(x, lengths)
 | 
			
		||||
    assert result.shape == (3, 7, 8, 16)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_seq_different_block_sizes_fp8():
 | 
			
		||||
    """Test pack_seq_triton with different block sizes for fp8."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
    N, H, D, B = 100, 16, 32, 4
 | 
			
		||||
    lengths = torch.tensor([25, 25, 25, 25], device=device)
 | 
			
		||||
 | 
			
		||||
    x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
 | 
			
		||||
    # Test different block sizes
 | 
			
		||||
    for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
 | 
			
		||||
        result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
 | 
			
		||||
 | 
			
		||||
        assert result.shape == (B, 25, H, D)
 | 
			
		||||
 | 
			
		||||
        # Check that valid data is preserved (within fp8 precision)
 | 
			
		||||
        for b in range(B):
 | 
			
		||||
            start_idx = b * 25
 | 
			
		||||
            expected_data = x[start_idx:start_idx + 25].to(torch.float32)
 | 
			
		||||
            actual_data = result[b, :25].to(torch.float32)
 | 
			
		||||
            assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_seq_shape_consistency():
 | 
			
		||||
    """Test that pack_seq_triton maintains shape consistency."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
    N, H, D, B = 20, 8, 16, 2
 | 
			
		||||
    lengths = torch.tensor([10, 10], device=device)
 | 
			
		||||
 | 
			
		||||
    x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
 | 
			
		||||
    result = pack_seq_triton(x, lengths)
 | 
			
		||||
 | 
			
		||||
    # Check shape consistency
 | 
			
		||||
    assert result.shape[0] == B  # Batch dimension
 | 
			
		||||
    assert result.shape[1] == lengths.max().item()  # Max sequence length
 | 
			
		||||
    assert result.shape[2:] == x.shape[1:]  # Feature dimensions preserved
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_pack_unpack_roundtrip_fp8():
 | 
			
		||||
    """Test that pack -> unpack gives us back the original data for fp8."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
 | 
			
		||||
    # Test cases with 3D tensors
 | 
			
		||||
    test_cases = [
 | 
			
		||||
        (6, 8, 4, 2, [3, 3]),
 | 
			
		||||
        (10, 4, 8, 3, [2, 4, 4]),
 | 
			
		||||
        (20, 16, 32, 4, [5, 5, 5, 5]),
 | 
			
		||||
        (15, 8, 16, 3, [7, 5, 3]),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    for N, H, D, B, lengths_list in test_cases:
 | 
			
		||||
        # Create input tensor with small values for fp8
 | 
			
		||||
        x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
        x = x.to(dtype=dtype)
 | 
			
		||||
        lengths = torch.tensor(lengths_list, device=device)
 | 
			
		||||
 | 
			
		||||
        # Pack the data
 | 
			
		||||
        packed = pack_seq_triton(x, lengths)
 | 
			
		||||
 | 
			
		||||
        # Unpack the data
 | 
			
		||||
        unpacked = unpack_seq_triton(packed, lengths)
 | 
			
		||||
 | 
			
		||||
        # Check that we get back the original data (within fp8 precision)
 | 
			
		||||
        assert unpacked.shape == x.shape
 | 
			
		||||
        x_f32 = x.to(torch.float32)
 | 
			
		||||
        unpacked_f32 = unpacked.to(torch.float32)
 | 
			
		||||
        assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
 | 
			
		||||
 | 
			
		||||
        # Unpack without explicit start locations (computed in kernel)
 | 
			
		||||
        unpacked_with_loc = unpack_seq_triton(packed, lengths)
 | 
			
		||||
        assert_close(x_f32,
 | 
			
		||||
                     unpacked_with_loc.to(torch.float32),
 | 
			
		||||
                     rtol=1e-3,
 | 
			
		||||
                     atol=1e-2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_unpack_seq_triton_edge_cases_fp8():
 | 
			
		||||
    """Test unpack function with edge cases for fp8."""
 | 
			
		||||
    device = "cuda"
 | 
			
		||||
    dtype = torch.float8_e4m3fn
 | 
			
		||||
 | 
			
		||||
    # Test with single batch element
 | 
			
		||||
    x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    lengths = torch.tensor([10], device=device)
 | 
			
		||||
    packed = pack_seq_triton(x, lengths)
 | 
			
		||||
    unpacked = unpack_seq_triton(packed, lengths)
 | 
			
		||||
    assert unpacked.shape == x.shape
 | 
			
		||||
    assert_close(x.to(torch.float32),
 | 
			
		||||
                 unpacked.to(torch.float32),
 | 
			
		||||
                 rtol=1e-1,
 | 
			
		||||
                 atol=1e-2)
 | 
			
		||||
 | 
			
		||||
    # Test with very short sequences
 | 
			
		||||
    x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    lengths = torch.tensor([1, 1, 1], device=device)
 | 
			
		||||
    packed = pack_seq_triton(x, lengths)
 | 
			
		||||
    unpacked = unpack_seq_triton(packed, lengths)
 | 
			
		||||
    # Only compare the first 3 elements that were actually packed
 | 
			
		||||
    assert_close(x[:3].to(torch.float32),
 | 
			
		||||
                 unpacked.to(torch.float32),
 | 
			
		||||
                 rtol=1e-1,
 | 
			
		||||
                 atol=1e-2)
 | 
			
		||||
 | 
			
		||||
    x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
 | 
			
		||||
    x = x.to(dtype=dtype)
 | 
			
		||||
    lengths = torch.tensor([5, 7, 3], device=device)
 | 
			
		||||
    packed = pack_seq_triton(x, lengths)
 | 
			
		||||
    unpacked = unpack_seq_triton(packed, lengths)
 | 
			
		||||
    assert unpacked.shape == x.shape
 | 
			
		||||
    assert_close(x.to(torch.float32),
 | 
			
		||||
                 unpacked.to(torch.float32),
 | 
			
		||||
                 rtol=1e-1,
 | 
			
		||||
                 atol=1e-2)
 | 
			
		||||
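
Taken together, the tests above fix the semantics of pack_seq_triton / unpack_seq_triton: a flat (N, H, D) tensor of concatenated sequences is packed into a padded (B, max_len, H, D) batch and unpacked back again. A plain PyTorch reference of that contract (a readability sketch, not the Triton implementation):

import torch

def pack_seq_reference(x: torch.Tensor, lengths: torch.Tensor,
                       pad_value: float = float("-inf")) -> torch.Tensor:
    # Pack concatenated sequences (N, H, D) into a padded (B, max_len, H, D)
    # batch; intended for regular float dtypes when cross-checking results.
    B = lengths.numel()
    max_len = int(lengths.max().item())
    out = torch.full((B, max_len, *x.shape[1:]), pad_value,
                     dtype=x.dtype, device=x.device)
    start = 0
    for b, n in enumerate(lengths.tolist()):
        out[b, :n] = x[start:start + n]
        start += n
    return out

def unpack_seq_reference(packed: torch.Tensor,
                         lengths: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_seq_reference: concatenate the valid prefixes again.
    return torch.cat(
        [packed[b, :n] for b, n in enumerate(lengths.tolist())], dim=0)
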
@ -207,6 +207,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
                                         trust_remote_code=True),
    "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3",  # noqa: E501
                                         trust_remote_code=True),
    "DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
    "Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
                                            min_transformers_version="4.54"),
    "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",

@ -8,7 +8,8 @@ import pytest

from vllm import LLM
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
                                         get_kv_cache_configs)
from vllm.v1.engine.core import EngineCore as V1EngineCore

from ..utils import create_new_process_for_each_test
@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
    # Avoid calling model.forward()
    def _initialize_kv_caches_v1(self, vllm_config):
        kv_cache_specs = self.model_executor.get_kv_cache_specs()
        scheduler_kv_cache_config = get_kv_cache_configs(
        kv_cache_configs = get_kv_cache_configs(
            vllm_config,
            kv_cache_specs,
            [10 * GiB_bytes],
        )[0]
        )
        scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
            kv_cache_configs)

        # gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
        return 1, 0, scheduler_kv_cache_config

@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
                                  raw_image_url: str, suffix: str):
    connector = MediaConnector()
    connector = MediaConnector(
        # Domain restriction should not apply to data URLs.
        allowed_media_domains=[
            "www.bogotobogo.com",
            "github.com",
        ])
    url_image = url_images[raw_image_url]

    try:
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
    modality_idxs = argsort_mm_positions(mm_positions)

    assert modality_idxs == expected_modality_idxs


@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_allowed_media_domains(video_url: str, num_frames: int):
    connector = MediaConnector(
        media_io_kwargs={"video": {
            "num_frames": num_frames,
        }},
        allowed_media_domains=[
            "www.bogotobogo.com",
            "github.com",
        ])

    video_sync, metadata_sync = connector.fetch_video(video_url)
    video_async, metadata_async = await connector.fetch_video_async(video_url)
    assert np.array_equal(video_sync, video_async)
    assert metadata_sync == metadata_async

    disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
    with pytest.raises(ValueError):
        _, _ = connector.fetch_video(disallowed_url)

    with pytest.raises(ValueError):
        _, _ = await connector.fetch_video_async(disallowed_url)

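The new test above exercises the allowed_media_domains restriction: remote URLs whose host is not in the allow-list are rejected, while data: URLs (as in the base64 test earlier) bypass the check. A minimal sketch of such a check, with assumed names rather than vLLM's actual MediaConnector internals:

from typing import Optional
from urllib.parse import urlparse

def assert_allowed_media_domain_sketch(
        url: str, allowed_media_domains: Optional[list[str]]) -> None:
    parsed = urlparse(url)
    # data: URLs carry the media inline, so no domain restriction applies.
    if allowed_media_domains is None or parsed.scheme == "data":
        return
    if parsed.hostname not in allowed_media_domains:
        raise ValueError(
            f"Domain {parsed.hostname!r} is not in allowed_media_domains")
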
@ -26,5 +26,5 @@ class DummyPlatform(Platform):

    def get_attn_backend_cls(self, backend_name, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink):
                             has_sink, use_sparse):
        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

@ -1,6 +1,7 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | 
			
		||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
 | 
			
		||||
from typing import Optional, Union
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
@ -10,6 +11,7 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
 | 
			
		||||
                                      create_standard_kv_cache_spec,
 | 
			
		||||
                                      create_vllm_config,
 | 
			
		||||
                                      get_attention_backend)
 | 
			
		||||
from vllm import _custom_ops as ops
 | 
			
		||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
 | 
			
		||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 | 
			
		||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
 | 
			
		||||
@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache(
 | 
			
		||||
        device: torch.device,
 | 
			
		||||
        num_blocks: int,
 | 
			
		||||
        common_attn_metadata: CommonAttentionMetadata,
 | 
			
		||||
        randomize_blocks: bool = True) -> torch.Tensor:
 | 
			
		||||
        randomize_blocks: bool = True,
 | 
			
		||||
        kv_cache_dtype: Optional[str] = None,
 | 
			
		||||
        scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
 | 
			
		||||
    """Create and prepopulate an MLA KV cache with context data.
 | 
			
		||||
    
 | 
			
		||||
    Args:
 | 
			
		||||
@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache(
 | 
			
		||||
        common_attn_metadata: Common attention metadata
 | 
			
		||||
        randomize_blocks: Whether to randomly permute blocks 
 | 
			
		||||
                          or use sequential order
 | 
			
		||||
        kv_cache_dtype: Optional kv cache dtype string. When set to
 | 
			
		||||
                        "fp8_ds_mla" the cache is populated using the
 | 
			
		||||
                        fp8 DeepSeek MLA layout via concat_and_cache_mla.
 | 
			
		||||
        scale: Scaling factor forwarded to concat_and_cache_mla when the
 | 
			
		||||
               fp8 cache layout is requested.
 | 
			
		||||
        
 | 
			
		||||
    Returns:
 | 
			
		||||
        MLA KV cache tensor
 | 
			
		||||
@ -105,23 +114,61 @@ def create_and_prepopulate_kv_cache(
 | 
			
		||||
    block_table = common_attn_metadata.block_table_tensor
 | 
			
		||||
    slot_mapping = common_attn_metadata.slot_mapping
 | 
			
		||||
 | 
			
		||||
    # Create MLA KV cache: (num_blocks, block_size, head_size)
 | 
			
		||||
    kv_cache = torch.empty(num_blocks,
 | 
			
		||||
                           block_size,
 | 
			
		||||
                           head_size,
 | 
			
		||||
                           dtype=dtype,
 | 
			
		||||
                           device=device)
 | 
			
		||||
    kv_cache_flat = kv_cache.view(-1, head_size)
 | 
			
		||||
    use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
 | 
			
		||||
 | 
			
		||||
    if use_fp8_ds_mla:
 | 
			
		||||
        if not kv_c_contexts:
 | 
			
		||||
            raise ValueError("kv_c_contexts cannot be empty when using"
 | 
			
		||||
                             " fp8_ds_mla cache dtype")
 | 
			
		||||
        kv_lora_rank = kv_c_contexts[0].shape[-1]
 | 
			
		||||
        rope_dim = k_pe_contexts[0].shape[-1]
 | 
			
		||||
        entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
 | 
			
		||||
        kv_cache = torch.zeros(num_blocks,
 | 
			
		||||
                               block_size,
 | 
			
		||||
                               entry_size,
 | 
			
		||||
                               dtype=torch.uint8,
 | 
			
		||||
                               device=device)
 | 
			
		||||
        scale_tensor = (scale
 | 
			
		||||
                        if isinstance(scale, torch.Tensor) else torch.tensor(
 | 
			
		||||
                            scale, dtype=torch.float32, device=device))
 | 
			
		||||
        scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
 | 
			
		||||
    else:
 | 
			
		||||
        # Create MLA KV cache: (num_blocks, block_size, head_size)
 | 
			
		||||
        kv_cache = torch.empty(num_blocks,
 | 
			
		||||
                               block_size,
 | 
			
		||||
                               head_size,
 | 
			
		||||
                               dtype=dtype,
 | 
			
		||||
                               device=device)
 | 
			
		||||
        kv_cache_flat = kv_cache.view(-1, head_size)
 | 
			
		||||
 | 
			
		||||
    # Populate the cache with the context tokens
 | 
			
		||||
    # Start from block_id=1 since block_id=0 is considered the null block
 | 
			
		||||
    start_block_idx = 1
 | 
			
		||||
    for i in range(batch_size):
 | 
			
		||||
        kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
 | 
			
		||||
        kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
 | 
			
		||||
        context_len = kv_c_context.shape[0]
 | 
			
		||||
        if context_len == 0:
 | 
			
		||||
            start_block_idx += cdiv(int(seq_lens[i]), block_size)
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        start = start_block_idx * block_size
 | 
			
		||||
        end = start + kv_context.shape[0]
 | 
			
		||||
        kv_cache_flat[start:end, ...] = kv_context
 | 
			
		||||
 | 
			
		||||
        if use_fp8_ds_mla:
 | 
			
		||||
            slots = torch.arange(context_len, device=device,
 | 
			
		||||
                                 dtype=torch.long) + start
 | 
			
		||||
            ops.concat_and_cache_mla(
 | 
			
		||||
                kv_c_context,
 | 
			
		||||
                k_pe_context.squeeze(1),
 | 
			
		||||
                kv_cache,
 | 
			
		||||
                slots,
 | 
			
		||||
                kv_cache_dtype="fp8_ds_mla",
 | 
			
		||||
                scale=scale_tensor,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            kv_context = torch.cat(
 | 
			
		||||
                [kv_c_context, k_pe_context.squeeze(1)], dim=-1)
 | 
			
		||||
            end = start + kv_context.shape[0]
 | 
			
		||||
            kv_cache_flat[start:end, ...] = kv_context
 | 
			
		||||
 | 
			
		||||
        # Stay block aligned and allocate enough blocks for the new tokens
 | 
			
		||||
        start_block_idx += cdiv(int(seq_lens[i]), block_size)
 | 
			
		||||
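The fp8_ds_mla branch above sizes each cache entry as kv_lora_rank + 4 * 4 + 2 * rope_dim bytes. As a quick editor's illustration of that arithmetic (the 512/64 dimensions are taken from the sparse-backend test below, not stated here):

# Illustration only, not part of the diff: one fp8_ds_mla cache entry packs the
# FP8 latent, four float32 per-tile scales, then the 16-bit rope values.
kv_lora_rank, rope_dim = 512, 64
latent_bytes = kv_lora_rank              # 1 byte per FP8 latent value
scale_bytes = 4 * 4                      # one float32 scale per 128-value tile
rope_bytes = 2 * rope_dim                # 2 bytes per bf16/fp16 rope value
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
assert latent_bytes + scale_bytes + rope_bytes == entry_size == 656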
tests/v1/attention/test_sparse_mla_backends.py (new file, 448 lines)
@ -0,0 +1,448 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | 
			
		||||
"""Unit tests for the FlashMLA sparse backend utilities."""
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
from types import MethodType, SimpleNamespace
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from tests.v1.attention.test_mla_backends import (
 | 
			
		||||
    BATCH_SPECS, BatchSpec, MockAttentionLayer,
 | 
			
		||||
    create_and_prepopulate_kv_cache)
 | 
			
		||||
from tests.v1.attention.utils import (create_common_attn_metadata,
 | 
			
		||||
                                      create_standard_kv_cache_spec,
 | 
			
		||||
                                      create_vllm_config)
 | 
			
		||||
from vllm import _custom_ops as ops
 | 
			
		||||
from vllm.attention.ops import flashmla
 | 
			
		||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
 | 
			
		||||
from vllm.utils import cdiv
 | 
			
		||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
 | 
			
		||||
    FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
 | 
			
		||||
    FlashMLASparseImpl, FlashMLASparseMetadata)
 | 
			
		||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
 | 
			
		||||
 | 
			
		||||
SPARSE_BACKEND_BATCH_SPECS = {
 | 
			
		||||
    name: BATCH_SPECS[name]
 | 
			
		||||
    for name in [
 | 
			
		||||
        "mixed_small",
 | 
			
		||||
        "mixed_medium",
 | 
			
		||||
        "small_prefill",
 | 
			
		||||
        "medium_prefill",
 | 
			
		||||
        "single_prefill",
 | 
			
		||||
    ]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
 | 
			
		||||
                                                          query_lens=[256] * 2)
 | 
			
		||||
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
 | 
			
		||||
    seq_lens=[256] * 2, query_lens=[256] * 2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _dequantize_fp8_ds_mla_entry(
 | 
			
		||||
        cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
 | 
			
		||||
        dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    """Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
 | 
			
		||||
 | 
			
		||||
    # The first kv_lora_rank bytes store FP8 latent values with one scale per
 | 
			
		||||
    # 128-element tile, written as float32 right after the latent payload.
 | 
			
		||||
    scales = cache_slice.view(torch.float32)[kv_lora_rank //
 | 
			
		||||
                                             4:kv_lora_rank // 4 + 4]
 | 
			
		||||
    latent = torch.empty(kv_lora_rank,
 | 
			
		||||
                         dtype=torch.float16,
 | 
			
		||||
                         device=cache_slice.device)
 | 
			
		||||
    for tile_idx in range(4):
 | 
			
		||||
        tile_start = tile_idx * 128
 | 
			
		||||
        tile_end = tile_start + 128
 | 
			
		||||
        ops.convert_fp8(latent[tile_start:tile_end],
 | 
			
		||||
                        cache_slice[tile_start:tile_end],
 | 
			
		||||
                        float(scales[tile_idx].item()),
 | 
			
		||||
                        kv_dtype="fp8")
 | 
			
		||||
    latent = latent.to(dtype)
 | 
			
		||||
 | 
			
		||||
    rope_offset = kv_lora_rank // 2 + 8
 | 
			
		||||
    rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
 | 
			
		||||
    return latent, rope_vals.clone()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _quantize_dequantize_fp8_ds_mla(
 | 
			
		||||
        kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
 | 
			
		||||
        scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
 | 
			
		||||
 | 
			
		||||
    if kv_c.numel() == 0:
 | 
			
		||||
        return kv_c.clone(), k_pe.clone()
 | 
			
		||||
 | 
			
		||||
    kv_lora_rank = kv_c.shape[-1]
 | 
			
		||||
    rope_dim = k_pe.shape[-1]
 | 
			
		||||
    num_tokens = kv_c.shape[0]
 | 
			
		||||
    num_blocks = max(1, math.ceil(num_tokens / block_size))
 | 
			
		||||
    entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
 | 
			
		||||
 | 
			
		||||
    tmp_cache = torch.zeros(num_blocks,
 | 
			
		||||
                            block_size,
 | 
			
		||||
                            entry_size,
 | 
			
		||||
                            dtype=torch.uint8,
 | 
			
		||||
                            device=kv_c.device)
 | 
			
		||||
    slot_mapping = torch.arange(num_tokens,
 | 
			
		||||
                                dtype=torch.long,
 | 
			
		||||
                                device=kv_c.device)
 | 
			
		||||
 | 
			
		||||
    ops.concat_and_cache_mla(kv_c,
 | 
			
		||||
                             k_pe,
 | 
			
		||||
                             tmp_cache,
 | 
			
		||||
                             slot_mapping,
 | 
			
		||||
                             kv_cache_dtype="fp8_ds_mla",
 | 
			
		||||
                             scale=scale)
 | 
			
		||||
 | 
			
		||||
    dequant_kv_c = torch.empty_like(kv_c)
 | 
			
		||||
    dequant_k_pe = torch.empty_like(k_pe)
 | 
			
		||||
 | 
			
		||||
    for token_idx in range(num_tokens):
 | 
			
		||||
        slot = slot_mapping[token_idx].item()
 | 
			
		||||
        block_idx = slot // block_size
 | 
			
		||||
        block_offset = slot % block_size
 | 
			
		||||
        cache_slice = tmp_cache[block_idx, block_offset]
 | 
			
		||||
        latent, rope_vals = _dequantize_fp8_ds_mla_entry(
 | 
			
		||||
            cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
 | 
			
		||||
        dequant_kv_c[token_idx] = latent
 | 
			
		||||
        dequant_k_pe[token_idx] = rope_vals
 | 
			
		||||
 | 
			
		||||
    return dequant_kv_c, dequant_k_pe
 | 
			
		||||
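A hypothetical standalone call to the round-trip helper above (editor's sketch; it assumes a CUDA device because concat_and_cache_mla is a CUDA op, and the 512/64 shapes mirror the decode-correctness test further down):

import torch
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
kv_c = torch.rand(8, 512, dtype=torch.bfloat16, device="cuda")   # [tokens, kv_lora_rank]
k_pe = torch.rand(8, 64, dtype=torch.bfloat16, device="cuda")    # [tokens, rope_dim]
# Quantize into the fp8_ds_mla layout and immediately dequantize, so reference
# inputs carry the same quantization error the backend under test will see.
kv_c_rt, k_pe_rt = _quantize_dequantize_fp8_ds_mla(kv_c, k_pe, block_size=64, scale=scale)
assert kv_c_rt.shape == kv_c.shape and k_pe_rt.shape == k_pe.shape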
 | 
			
		||||
 | 
			
		||||
def test_sparse_backend_metadata_registration():
 | 
			
		||||
    backend = FlashMLASparseBackend
 | 
			
		||||
 | 
			
		||||
    assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
 | 
			
		||||
    assert backend.get_metadata_cls() is FlashMLASparseMetadata
 | 
			
		||||
    assert backend.get_impl_cls() is FlashMLASparseImpl
 | 
			
		||||
 | 
			
		||||
    dtype_list = backend.get_supported_dtypes()
 | 
			
		||||
    assert torch.bfloat16 in dtype_list
 | 
			
		||||
 | 
			
		||||
    shape = backend.get_kv_cache_shape(num_blocks=2,
 | 
			
		||||
                                       block_size=64,
 | 
			
		||||
                                       num_kv_heads=1,
 | 
			
		||||
                                       head_size=576)
 | 
			
		||||
    assert shape == (2, 64, 576)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_sparse_decode_metadata_filters_prefill_indices():
 | 
			
		||||
    prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
 | 
			
		||||
    metadata = FlashMLASparseDecodeAndContextMetadata(
 | 
			
		||||
        scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
 | 
			
		||||
        num_splits=torch.tensor([1, 1], dtype=torch.int32),
 | 
			
		||||
        cache_lens=torch.tensor([10, 12], dtype=torch.int32),
 | 
			
		||||
        prefill_context_lengths=prefill_context_lengths,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
 | 
			
		||||
 | 
			
		||||
    context_indices, new_token_indices = metadata.filter_prefill_indices(
 | 
			
		||||
        indices)
 | 
			
		||||
 | 
			
		||||
    expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
 | 
			
		||||
                                    dtype=torch.int32)
 | 
			
		||||
    expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
 | 
			
		||||
                                       dtype=torch.int32)
 | 
			
		||||
 | 
			
		||||
    assert torch.equal(context_indices, expected_context)
 | 
			
		||||
    assert torch.equal(new_token_indices, expected_new_tokens)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_sparse_impl_zero_fills_when_metadata_missing():
 | 
			
		||||
    impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
 | 
			
		||||
    dummy_layer = object()
 | 
			
		||||
    q = torch.zeros((2, 1, 3))
 | 
			
		||||
    k_c = torch.zeros((2, 3))
 | 
			
		||||
    k_pe = torch.zeros((2, 1, 1))
 | 
			
		||||
    kv_cache = torch.zeros((1, 1, 1))
 | 
			
		||||
    output = torch.ones((2, 4))
 | 
			
		||||
 | 
			
		||||
    result = FlashMLASparseImpl.forward(impl,
 | 
			
		||||
                                        dummy_layer,
 | 
			
		||||
                                        q,
 | 
			
		||||
                                        k_c,
 | 
			
		||||
                                        k_pe,
 | 
			
		||||
                                        kv_cache,
 | 
			
		||||
                                        attn_metadata=None,
 | 
			
		||||
                                        output=output)
 | 
			
		||||
 | 
			
		||||
    assert result is output
 | 
			
		||||
    assert torch.all(result == 0)
 | 
			
		||||
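The test above pins the contract that FlashMLASparseImpl.forward, when called without attention metadata (typically the profiling dry run), simply zero-fills and returns the caller's output buffer. A minimal sketch of that guard pattern (editor's illustration, not the backend code):

import torch

def forward_sketch(output: torch.Tensor, attn_metadata=None) -> torch.Tensor:
    # No metadata: return a zeroed buffer untouched by the KV cache so the
    # caller still gets a tensor of the right shape.
    if attn_metadata is None:
        output.fill_(0)
        return output
    raise NotImplementedError("real attention path elided in this sketch")

buf = torch.ones(2, 4)
assert forward_sketch(buf) is buf and torch.all(buf == 0)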
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
 | 
			
		||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
 | 
			
		||||
def test_sparse_backend_decode_correctness(dist_init, batch_name,
 | 
			
		||||
                                           kv_cache_dtype):
 | 
			
		||||
    if not torch.cuda.is_available():
 | 
			
		||||
        pytest.skip("CUDA is required for sparse MLA decode test")
 | 
			
		||||
 | 
			
		||||
    device = torch.device("cuda")
 | 
			
		||||
    dtype = torch.bfloat16
 | 
			
		||||
 | 
			
		||||
    batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
 | 
			
		||||
 | 
			
		||||
    # Model hyper-parameters (kept intentionally small for the unit test)
 | 
			
		||||
    num_heads = 128
 | 
			
		||||
    kv_lora_rank = 512
 | 
			
		||||
    qk_nope_head_dim = 128
 | 
			
		||||
    qk_rope_head_dim = 64
 | 
			
		||||
    v_head_dim = 128
 | 
			
		||||
    head_size = kv_lora_rank + qk_rope_head_dim
 | 
			
		||||
    topk_tokens = 2048
 | 
			
		||||
 | 
			
		||||
    max_seqlen = max(batch_spec.seq_lens)
 | 
			
		||||
    total_cache_tokens = sum(batch_spec.seq_lens)
 | 
			
		||||
    block_size = 64
 | 
			
		||||
 | 
			
		||||
    vllm_config = create_vllm_config(
 | 
			
		||||
        model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
 | 
			
		||||
        max_model_len=max_seqlen,
 | 
			
		||||
        num_gpu_blocks=max(2048,
 | 
			
		||||
                           cdiv(total_cache_tokens, block_size) + 1),
 | 
			
		||||
        block_size=block_size)
 | 
			
		||||
    model_config = vllm_config.model_config
 | 
			
		||||
    model_config.hf_config = SimpleNamespace(
 | 
			
		||||
        attn_module_list_cfg=[{
 | 
			
		||||
            "topk_tokens": topk_tokens
 | 
			
		||||
        }])
 | 
			
		||||
    model_config.hf_text_config = SimpleNamespace(
 | 
			
		||||
        q_lora_rank=None,
 | 
			
		||||
        kv_lora_rank=kv_lora_rank,
 | 
			
		||||
        qk_nope_head_dim=qk_nope_head_dim,
 | 
			
		||||
        qk_rope_head_dim=qk_rope_head_dim,
 | 
			
		||||
        v_head_dim=v_head_dim,
 | 
			
		||||
        model_type="deepseek_v2",
 | 
			
		||||
    )
 | 
			
		||||
    model_config.dtype = dtype
 | 
			
		||||
    model_config.get_num_attention_heads = MethodType(
 | 
			
		||||
        lambda self, parallel_config: num_heads, model_config)
 | 
			
		||||
    model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
 | 
			
		||||
                                               model_config)
 | 
			
		||||
    model_config.get_head_size = MethodType(lambda self: head_size,
 | 
			
		||||
                                            model_config)
 | 
			
		||||
    model_config.get_sliding_window = MethodType(lambda self: None,
 | 
			
		||||
                                                 model_config)
 | 
			
		||||
 | 
			
		||||
    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
 | 
			
		||||
 | 
			
		||||
    torch.manual_seed(0)
 | 
			
		||||
 | 
			
		||||
    scale = 1.0 / math.sqrt(head_size)
 | 
			
		||||
 | 
			
		||||
    # Shared MLA projection weights to keep reference and backend in sync
 | 
			
		||||
    W_UK = torch.randn(kv_lora_rank,
 | 
			
		||||
                       num_heads,
 | 
			
		||||
                       qk_nope_head_dim,
 | 
			
		||||
                       dtype=dtype,
 | 
			
		||||
                       device=device)
 | 
			
		||||
    W_UV = torch.randn(kv_lora_rank,
 | 
			
		||||
                       num_heads,
 | 
			
		||||
                       v_head_dim,
 | 
			
		||||
                       dtype=dtype,
 | 
			
		||||
                       device=device)
 | 
			
		||||
 | 
			
		||||
    # Build synthetic decode-only workload
 | 
			
		||||
    seq_lens = batch_spec.seq_lens
 | 
			
		||||
    query_lens = batch_spec.query_lens
 | 
			
		||||
 | 
			
		||||
    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
 | 
			
		||||
    kv_c_contexts, k_pe_contexts = [], []
 | 
			
		||||
    reference_outputs = []
 | 
			
		||||
 | 
			
		||||
    kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
 | 
			
		||||
 | 
			
		||||
    for i in range(batch_spec.batch_size):
 | 
			
		||||
        s_len = seq_lens[i]
 | 
			
		||||
        q_len = query_lens[i]
 | 
			
		||||
        ctx_len = s_len - q_len
 | 
			
		||||
 | 
			
		||||
        q_c = torch.rand(q_len,
 | 
			
		||||
                         num_heads,
 | 
			
		||||
                         qk_nope_head_dim + qk_rope_head_dim,
 | 
			
		||||
                         dtype=dtype,
 | 
			
		||||
                         device=device)
 | 
			
		||||
        kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
 | 
			
		||||
        k_pe_full = torch.rand(s_len,
 | 
			
		||||
                               1,
 | 
			
		||||
                               qk_rope_head_dim,
 | 
			
		||||
                               dtype=dtype,
 | 
			
		||||
                               device=device)
 | 
			
		||||
 | 
			
		||||
        kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
 | 
			
		||||
            kv_c_full,
 | 
			
		||||
            k_pe_full.squeeze(1),
 | 
			
		||||
            block_size=vllm_config.cache_config.block_size,
 | 
			
		||||
            scale=kv_cache_scale,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
 | 
			
		||||
        ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
 | 
			
		||||
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
 | 
			
		||||
 | 
			
		||||
        k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
 | 
			
		||||
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
 | 
			
		||||
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
 | 
			
		||||
 | 
			
		||||
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
 | 
			
		||||
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
 | 
			
		||||
        attn_mask[:, ctx_len:] = causal_mask
 | 
			
		||||
 | 
			
		||||
        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
 | 
			
		||||
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
 | 
			
		||||
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
 | 
			
		||||
 | 
			
		||||
        sdpa_out = torch.nn.functional.scaled_dot_product_attention(
 | 
			
		||||
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
 | 
			
		||||
        sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
 | 
			
		||||
 | 
			
		||||
        sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
 | 
			
		||||
        reference_outputs.append(sdpa_out.flatten(start_dim=-2))
 | 
			
		||||
 | 
			
		||||
        all_q_vllm.append(q_c)
 | 
			
		||||
        all_kv_c_vllm.append(kv_c_full[ctx_len:])
 | 
			
		||||
        all_k_pe_vllm.append(k_pe_full[ctx_len:])
 | 
			
		||||
        kv_c_contexts.append(kv_c_full[:ctx_len + 1])
 | 
			
		||||
        k_pe_contexts.append(k_pe_full[:ctx_len + 1])
 | 
			
		||||
 | 
			
		||||
    query_vllm = torch.cat(all_q_vllm, dim=0)
 | 
			
		||||
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
 | 
			
		||||
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
 | 
			
		||||
    sdpa_reference = torch.cat(reference_outputs, dim=0)
 | 
			
		||||
 | 
			
		||||
    vllm_config.cache_config.cache_dtype = kv_cache_dtype
 | 
			
		||||
 | 
			
		||||
    common_attn_metadata = create_common_attn_metadata(
 | 
			
		||||
        batch_spec,
 | 
			
		||||
        vllm_config.cache_config.block_size,
 | 
			
		||||
        device,
 | 
			
		||||
        arange_block_indices=True)
 | 
			
		||||
 | 
			
		||||
    kv_cache = create_and_prepopulate_kv_cache(
 | 
			
		||||
        kv_c_contexts=kv_c_contexts,
 | 
			
		||||
        k_pe_contexts=k_pe_contexts,
 | 
			
		||||
        block_size=vllm_config.cache_config.block_size,
 | 
			
		||||
        head_size=head_size,
 | 
			
		||||
        dtype=dtype,
 | 
			
		||||
        device=device,
 | 
			
		||||
        num_blocks=vllm_config.cache_config.num_gpu_blocks,
 | 
			
		||||
        common_attn_metadata=common_attn_metadata,
 | 
			
		||||
        randomize_blocks=False,
 | 
			
		||||
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
 | 
			
		||||
        scale=kv_cache_scale,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    builder_cls = FlashMLASparseBackend.get_builder_cls()
 | 
			
		||||
    builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
 | 
			
		||||
    metadata = builder.build(common_prefix_len=0,
 | 
			
		||||
                             common_attn_metadata=common_attn_metadata)
 | 
			
		||||
 | 
			
		||||
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
 | 
			
		||||
                        dtype=np.int32)
 | 
			
		||||
    seg_lengths = np.diff(starts)
 | 
			
		||||
    positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
 | 
			
		||||
        starts[:-1], seg_lengths)
 | 
			
		||||
    seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
 | 
			
		||||
    prefix_lengths = seq_lengths - seg_lengths
 | 
			
		||||
    positions += np.repeat(prefix_lengths, seg_lengths)
 | 
			
		||||
 | 
			
		||||
    pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
 | 
			
		||||
    topk = metadata.topk_tokens
 | 
			
		||||
    debug_indices = torch.arange(topk, device=device,
 | 
			
		||||
                                 dtype=torch.int32).unsqueeze(0)
 | 
			
		||||
    token_positions = pos_gpu.unsqueeze(1)
 | 
			
		||||
    causal_mask = (debug_indices <= token_positions)
 | 
			
		||||
    debug_indices = torch.where(causal_mask, debug_indices,
 | 
			
		||||
                                torch.full_like(debug_indices, -1))
 | 
			
		||||
 | 
			
		||||
    # FlashMLASparseImpl now reads top-k indices from the indexer-provided
 | 
			
		||||
    # buffer, so emulate that contract with a simple namespace mock.
 | 
			
		||||
    debug_indices = debug_indices.expand(metadata.num_actual_tokens,
 | 
			
		||||
                                         -1).clone()
 | 
			
		||||
    mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
 | 
			
		||||
 | 
			
		||||
    ok, reason = flashmla.is_flashmla_supported()
 | 
			
		||||
    if not ok:
 | 
			
		||||
        pytest.skip(reason)
 | 
			
		||||
 | 
			
		||||
    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
 | 
			
		||||
    kv_b_proj_weight = kv_b_proj_weight.view(
 | 
			
		||||
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
 | 
			
		||||
 | 
			
		||||
    mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
 | 
			
		||||
                                          output_size=num_heads *
 | 
			
		||||
                                          (qk_nope_head_dim + v_head_dim),
 | 
			
		||||
                                          bias=False).to(device=device,
 | 
			
		||||
                                                         dtype=dtype)
 | 
			
		||||
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
 | 
			
		||||
 | 
			
		||||
    impl_cls = FlashMLASparseBackend.get_impl_cls()
 | 
			
		||||
    impl = impl_cls(num_heads=num_heads,
 | 
			
		||||
                    head_size=head_size,
 | 
			
		||||
                    scale=scale,
 | 
			
		||||
                    num_kv_heads=1,
 | 
			
		||||
                    alibi_slopes=None,
 | 
			
		||||
                    sliding_window=None,
 | 
			
		||||
                    kv_cache_dtype=vllm_config.cache_config.cache_dtype,
 | 
			
		||||
                    logits_soft_cap=None,
 | 
			
		||||
                    attn_type="decoder",
 | 
			
		||||
                    kv_sharing_target_layer_name=None,
 | 
			
		||||
                    q_lora_rank=None,
 | 
			
		||||
                    kv_lora_rank=kv_lora_rank,
 | 
			
		||||
                    qk_nope_head_dim=qk_nope_head_dim,
 | 
			
		||||
                    qk_rope_head_dim=qk_rope_head_dim,
 | 
			
		||||
                    qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
 | 
			
		||||
                    v_head_dim=v_head_dim,
 | 
			
		||||
                    kv_b_proj=mock_kv_b_proj,
 | 
			
		||||
                    indexer=mock_indexer)
 | 
			
		||||
 | 
			
		||||
    impl.process_weights_after_loading(dtype)
 | 
			
		||||
 | 
			
		||||
    layer = MockAttentionLayer(device)
 | 
			
		||||
    out_buffer = torch.empty(metadata.num_actual_tokens,
 | 
			
		||||
                             num_heads * v_head_dim,
 | 
			
		||||
                             dtype=dtype,
 | 
			
		||||
                             device=device)
 | 
			
		||||
 | 
			
		||||
    backend_output = impl.forward(layer,
 | 
			
		||||
                                  query_vllm,
 | 
			
		||||
                                  kv_c_vllm,
 | 
			
		||||
                                  k_pe_vllm,
 | 
			
		||||
                                  kv_cache,
 | 
			
		||||
                                  metadata,
 | 
			
		||||
                                  output=out_buffer)
 | 
			
		||||
 | 
			
		||||
    assert backend_output.shape == sdpa_reference.shape
 | 
			
		||||
    assert backend_output.dtype == sdpa_reference.dtype
 | 
			
		||||
    assert torch.isfinite(backend_output).all()
 | 
			
		||||
 | 
			
		||||
    torch.testing.assert_close(backend_output,
 | 
			
		||||
                               sdpa_reference,
 | 
			
		||||
                               rtol=0.5,
 | 
			
		||||
                               atol=0.5)
 | 
			
		||||
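The SDPA reference above relies on MLA weight absorption: q_nope is folded into W_UK so attention scores are computed directly against the cached latent, and W_UV expands the result afterwards. A small self-contained check of that equivalence (editor's sketch with toy dimensions, not part of the diff):

import torch
L, N, H = 8, 2, 4                        # latent dim, heads, nope head dim (toy sizes)
W_UK = torch.randn(L, N, H)
q_nope = torch.randn(3, N, H)            # 3 query tokens
kv_c = torch.randn(5, L)                 # 5 cached latent vectors
k_nope = torch.einsum("sl,lnh->snh", kv_c, W_UK)              # up-project keys
scores_direct = torch.einsum("qnh,snh->qns", q_nope, k_nope)
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)          # absorb W_UK into the query
scores_absorbed = torch.einsum("qnl,sl->qns", ql_nope, kv_c)
torch.testing.assert_close(scores_direct, scores_absorbed, rtol=1e-4, atol=1e-4)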
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(
 | 
			
		||||
    "seq_lens,max_buf,start,expected",
 | 
			
		||||
    [
 | 
			
		||||
        # Basic split: totals per chunk ≤ max_buf
 | 
			
		||||
        (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
 | 
			
		||||
        # Non-zero start index
 | 
			
		||||
        (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
 | 
			
		||||
        # Exact fits should split between items when adding the next would
 | 
			
		||||
        # overflow
 | 
			
		||||
        (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
 | 
			
		||||
        # All requests fit in a single chunk
 | 
			
		||||
        (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
 | 
			
		||||
        # Large buffer with non-zero start
 | 
			
		||||
        (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
 | 
			
		||||
    out = split_prefill_chunks(seq_lens, max_buf, start)
 | 
			
		||||
    assert out == expected
 | 
			
		||||
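The expected tuples above pin down a greedy packing rule: requests are taken left to right and a chunk is closed as soon as adding the next sequence would exceed the buffer. A sketch consistent with these cases (editor's illustration, not the vLLM split_prefill_chunks implementation):

def split_prefill_chunks_sketch(seq_lens, max_buf, start=0):
    # Returns (start, end) index ranges whose summed seq_lens stay within max_buf.
    chunks, chunk_start, used = [], start, 0
    for i in range(start, len(seq_lens)):
        n = int(seq_lens[i])
        if used and used + n > max_buf:
            chunks.append((chunk_start, i))
            chunk_start, used = i, 0
        used += n
    if chunk_start < len(seq_lens):
        chunks.append((chunk_start, len(seq_lens)))
    return chunks

assert split_prefill_chunks_sketch([2, 3, 4, 2], 5) == [(0, 2), (2, 3), (3, 4)]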
@ -168,7 +168,6 @@ def create_standard_kv_cache_spec(
 | 
			
		||||
            vllm_config.parallel_config),
 | 
			
		||||
        head_size=vllm_config.model_config.get_head_size(),
 | 
			
		||||
        dtype=vllm_config.model_config.dtype,
 | 
			
		||||
        use_mla=vllm_config.model_config.use_mla,
 | 
			
		||||
        sliding_window=vllm_config.model_config.get_sliding_window(),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -24,7 +24,8 @@ from vllm.v1.core.kv_cache_utils import (
 | 
			
		||||
    make_block_hash_with_group_id)
 | 
			
		||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
 | 
			
		||||
                                        KVCacheGroupSpec, KVCacheSpec,
 | 
			
		||||
                                        KVCacheTensor, SlidingWindowSpec,
 | 
			
		||||
                                        KVCacheTensor, MLAAttentionSpec,
 | 
			
		||||
                                        SlidingWindowSpec,
 | 
			
		||||
                                        UniformTypeKVCacheSpecs)
 | 
			
		||||
from vllm.v1.metrics.stats import PrefixCacheStats
 | 
			
		||||
from vllm.v1.request import Request
 | 
			
		||||
@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16,
 | 
			
		||||
                      num_kv_heads=2,
 | 
			
		||||
                      head_size=64,
 | 
			
		||||
                      dtype=torch.float32,
 | 
			
		||||
                      use_mla=False,
 | 
			
		||||
                      sliding_window=None):
 | 
			
		||||
    return FullAttentionSpec(block_size=block_size,
 | 
			
		||||
                             num_kv_heads=num_kv_heads,
 | 
			
		||||
                             head_size=head_size,
 | 
			
		||||
                             dtype=dtype,
 | 
			
		||||
                             use_mla=use_mla,
 | 
			
		||||
                             sliding_window=sliding_window)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16,
 | 
			
		||||
                            num_kv_heads=2,
 | 
			
		||||
                            head_size=64,
 | 
			
		||||
                            dtype=torch.float32,
 | 
			
		||||
                            use_mla=False,
 | 
			
		||||
                            sliding_window=1):
 | 
			
		||||
    return SlidingWindowSpec(block_size=block_size,
 | 
			
		||||
                             num_kv_heads=num_kv_heads,
 | 
			
		||||
                             head_size=head_size,
 | 
			
		||||
                             dtype=dtype,
 | 
			
		||||
                             use_mla=use_mla,
 | 
			
		||||
                             sliding_window=sliding_window)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -894,7 +891,6 @@ def test_merge_kv_cache_spec():
 | 
			
		||||
            num_kv_heads=full_spec.num_kv_heads,
 | 
			
		||||
            head_size=full_spec.head_size,
 | 
			
		||||
            dtype=full_spec.dtype,
 | 
			
		||||
            use_mla=full_spec.use_mla,
 | 
			
		||||
            sliding_window=1,
 | 
			
		||||
        ),
 | 
			
		||||
    ]
 | 
			
		||||
@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len,
 | 
			
		||||
            num_kv_heads=32,
 | 
			
		||||
            head_size=128,
 | 
			
		||||
            dtype=torch.float16,
 | 
			
		||||
            use_mla=False,
 | 
			
		||||
        )
 | 
			
		||||
    # Estimate the maximum model length, 16384 model_len need 8GB
 | 
			
		||||
    estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
 | 
			
		||||
@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config():
 | 
			
		||||
        num_kv_heads=32,
 | 
			
		||||
        head_size=128,
 | 
			
		||||
        dtype=torch.float16,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    sliding_window_spec = SlidingWindowSpec(
 | 
			
		||||
@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config():
 | 
			
		||||
        num_kv_heads=32,
 | 
			
		||||
        head_size=128,
 | 
			
		||||
        dtype=torch.float16,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
        sliding_window=1024,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config():
 | 
			
		||||
            KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def new_mla_spec(cache_dtype_str=None):
 | 
			
		||||
    return MLAAttentionSpec(block_size=16,
 | 
			
		||||
                            num_kv_heads=16,
 | 
			
		||||
                            head_size=64,
 | 
			
		||||
                            dtype=torch.float32,
 | 
			
		||||
                            cache_dtype_str=cache_dtype_str)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_merge_mla_spec():
 | 
			
		||||
    kv_cache_specs = [
 | 
			
		||||
        new_mla_spec(),
 | 
			
		||||
        new_mla_spec(),
 | 
			
		||||
    ]
 | 
			
		||||
    mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
 | 
			
		||||
    assert mla_spec == new_mla_spec()
 | 
			
		||||
 | 
			
		||||
    kv_cache_specs = [
 | 
			
		||||
        new_mla_spec(cache_dtype_str="fp8_ds_mla"),
 | 
			
		||||
        new_mla_spec(cache_dtype_str="fp8_ds_mla"),
 | 
			
		||||
    ]
 | 
			
		||||
    mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
 | 
			
		||||
    assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
 | 
			
		||||
 | 
			
		||||
    kv_cache_specs = [
 | 
			
		||||
        new_mla_spec(cache_dtype_str="fp8_ds_mla"),
 | 
			
		||||
        new_mla_spec(cache_dtype_str=None),
 | 
			
		||||
    ]
 | 
			
		||||
    with pytest.raises(AssertionError):
 | 
			
		||||
        kv_cache_specs[0].merge(kv_cache_specs)
 | 
			
		||||
 | 
			
		||||
    kv_cache_specs = [
 | 
			
		||||
        new_kv_cache_spec(),
 | 
			
		||||
        new_mla_spec(),
 | 
			
		||||
    ]
 | 
			
		||||
    with pytest.raises(AssertionError):
 | 
			
		||||
        kv_cache_specs[0].merge(kv_cache_specs)
 | 
			
		||||
 | 
			
		||||
    kv_cache_specs = [
 | 
			
		||||
        new_mla_spec(cache_dtype_str="fp8_ds_mla"),
 | 
			
		||||
        new_kv_cache_spec(),
 | 
			
		||||
    ]
 | 
			
		||||
    with pytest.raises(AssertionError):
 | 
			
		||||
        kv_cache_specs[0].merge(kv_cache_specs)
 | 
			
		||||
 | 
			
		||||
@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
 | 
			
		||||
        kv_cache_groups=[
 | 
			
		||||
            KVCacheGroupSpec(
 | 
			
		||||
                ["layer"],
 | 
			
		||||
                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
 | 
			
		||||
                FullAttentionSpec(block_size, 1, 1, torch.float32),
 | 
			
		||||
            )
 | 
			
		||||
        ],
 | 
			
		||||
    )
 | 
			
		||||
@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
 | 
			
		||||
        kv_cache_groups=[
 | 
			
		||||
            KVCacheGroupSpec(
 | 
			
		||||
                ["layer1"],
 | 
			
		||||
                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
 | 
			
		||||
                FullAttentionSpec(block_size, 1, 1, torch.float32),
 | 
			
		||||
            ),
 | 
			
		||||
            KVCacheGroupSpec(
 | 
			
		||||
                ["layer2"],
 | 
			
		||||
@ -98,7 +98,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
 | 
			
		||||
                                  1,
 | 
			
		||||
                                  1,
 | 
			
		||||
                                  torch.float32,
 | 
			
		||||
                                  False,
 | 
			
		||||
                                  sliding_window=2 * block_size),
 | 
			
		||||
            ),
 | 
			
		||||
            KVCacheGroupSpec(
 | 
			
		||||
@ -107,7 +106,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
 | 
			
		||||
                                  1,
 | 
			
		||||
                                  1,
 | 
			
		||||
                                  torch.float32,
 | 
			
		||||
                                  False,
 | 
			
		||||
                                  sliding_window=2 * block_size),
 | 
			
		||||
            ),
 | 
			
		||||
        ],
 | 
			
		||||
@ -1338,7 +1336,6 @@ def test_eagle_with_sliding_window():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        sliding_window=block_size,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
    manager = KVCacheManager(
 | 
			
		||||
        KVCacheConfig(
 | 
			
		||||
 | 
			
		||||
@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        attention_chunk_size=4,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
 | 
			
		||||
@ -100,7 +99,6 @@ def test_sliding_window_possible_cached_prefix():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        sliding_window=4,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
 | 
			
		||||
@ -165,7 +163,6 @@ def test_chunked_local_attention_remove_skipped_blocks():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        attention_chunk_size=4,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
 | 
			
		||||
@ -217,7 +214,6 @@ def test_sliding_window_remove_skipped_blocks():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        sliding_window=4,
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
 | 
			
		||||
@ -285,7 +281,6 @@ def test_get_num_blocks_to_allocate():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        sliding_window=4,  # Placeholder value, not related to test result
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
 | 
			
		||||
@ -308,7 +303,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
 | 
			
		||||
        head_size=1,
 | 
			
		||||
        dtype=torch.float32,
 | 
			
		||||
        attention_chunk_size=4,  # Placeholder value, not related to test result
 | 
			
		||||
        use_mla=False,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
 | 
			
		||||
from vllm.distributed import cleanup_dist_env_and_memory
 | 
			
		||||
from vllm.platforms import current_platform
 | 
			
		||||
 | 
			
		||||
MTP_SIMILARITY_RATE = 0.8
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_test_prompts(mm_enabled: bool):
 | 
			
		||||
    prompt_types = ["repeat", "sentence"]
 | 
			
		||||
@ -222,3 +224,66 @@ def test_eagle_correctness(
 | 
			
		||||
        del spec_llm
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        cleanup_dist_env_and_memory()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
 | 
			
		||||
    (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
 | 
			
		||||
    (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
 | 
			
		||||
],
 | 
			
		||||
                         ids=["mimo", "deepseek"])
 | 
			
		||||
def test_mtp_correctness(
 | 
			
		||||
    monkeypatch: pytest.MonkeyPatch,
 | 
			
		||||
    sampling_config: SamplingParams,
 | 
			
		||||
    model_setup: tuple[str, str, int],
 | 
			
		||||
    mm_enabled: bool,
 | 
			
		||||
):
 | 
			
		||||
    # Generate test prompts inside the function instead of using fixture
 | 
			
		||||
    test_prompts = get_test_prompts(mm_enabled)
 | 
			
		||||
    '''
 | 
			
		||||
    Compare the outputs of an original LLM and a speculative LLM;
 | 
			
		||||
    they should be the same when using MTP speculative decoding.
 | 
			
		||||
    model_setup: (method, model_name, tp_size)
 | 
			
		||||
    '''
 | 
			
		||||
    with monkeypatch.context() as m:
 | 
			
		||||
        m.setenv("VLLM_USE_V1", "1")
 | 
			
		||||
        m.setenv("VLLM_MLA_DISABLE", "1")
 | 
			
		||||
 | 
			
		||||
        method, model_name, tp_size = model_setup
 | 
			
		||||
 | 
			
		||||
        ref_llm = LLM(model=model_name,
 | 
			
		||||
                      max_model_len=2048,
 | 
			
		||||
                      tensor_parallel_size=tp_size,
 | 
			
		||||
                      trust_remote_code=True)
 | 
			
		||||
        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
 | 
			
		||||
        del ref_llm
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        cleanup_dist_env_and_memory()
 | 
			
		||||
 | 
			
		||||
        spec_llm = LLM(
 | 
			
		||||
            model=model_name,
 | 
			
		||||
            trust_remote_code=True,
 | 
			
		||||
            tensor_parallel_size=tp_size,
 | 
			
		||||
            speculative_config={
 | 
			
		||||
                "method": method,
 | 
			
		||||
                "num_speculative_tokens": 1,
 | 
			
		||||
                "max_model_len": 2048,
 | 
			
		||||
            },
 | 
			
		||||
            max_model_len=2048,
 | 
			
		||||
        )
 | 
			
		||||
        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
 | 
			
		||||
        matches = 0
 | 
			
		||||
        misses = 0
 | 
			
		||||
        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
 | 
			
		||||
            if ref_output.outputs[0].text == spec_output.outputs[0].text:
 | 
			
		||||
                matches += 1
 | 
			
		||||
            else:
 | 
			
		||||
                misses += 1
 | 
			
		||||
                print(f"ref_output: {ref_output.outputs[0].text}")
 | 
			
		||||
                print(f"spec_output: {spec_output.outputs[0].text}")
 | 
			
		||||
 | 
			
		||||
        # Heuristic: expect at least 80% of the prompts to match exactly
 | 
			
		||||
        # Upon failure, inspect the outputs to check for inaccuracy.
 | 
			
		||||
        assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
 | 
			
		||||
        del spec_llm
 | 
			
		||||
        torch.cuda.empty_cache()
 | 
			
		||||
        cleanup_dist_env_and_memory()
 | 
			
		||||
 | 
			
		||||
@ -836,8 +836,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
 | 
			
		||||
        mock_spec = FullAttentionSpec(block_size=16,
 | 
			
		||||
                                      num_kv_heads=1,
 | 
			
		||||
                                      head_size=64,
 | 
			
		||||
                                      dtype=torch.float16,
 | 
			
		||||
                                      use_mla=False)
 | 
			
		||||
                                      dtype=torch.float16)
 | 
			
		||||
 | 
			
		||||
        mock_executor.get_kv_cache_specs.return_value = [{
 | 
			
		||||
            "default": mock_spec
 | 
			
		||||
 | 
			
		||||
@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
 | 
			
		||||
        time.sleep(self._hand_shake_latency)
 | 
			
		||||
        # These should've been done in register_kv_caches(), called by
 | 
			
		||||
        # gpu_model_runner. Here we just hardcode some dummy values.
 | 
			
		||||
        self.slot_size_bytes = 4096
 | 
			
		||||
        self.block_len = self.slot_size_bytes * self.block_size
 | 
			
		||||
        slot_size_bytes = 4096
 | 
			
		||||
        self.slot_size_per_layer = [slot_size_bytes]
 | 
			
		||||
        self.block_len_per_layer = [slot_size_bytes * self.block_size]
 | 
			
		||||
        self.num_blocks = 1
 | 
			
		||||
        self.dst_num_blocks[self.engine_id] = self.num_blocks
 | 
			
		||||
 | 
			
		||||
@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
 | 
			
		||||
                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
 | 
			
		||||
                kv_caches_base_addr=[0],
 | 
			
		||||
                num_blocks=1,
 | 
			
		||||
                block_len=self.block_len,
 | 
			
		||||
                block_lens=self.block_len_per_layer,
 | 
			
		||||
                attn_backend_name=self.backend_name,
 | 
			
		||||
                # `self.kv_cache_layout` is only forced to HND when vllm engine
 | 
			
		||||
                # is started. We mock HND here.
 | 
			
		||||
@ -485,8 +486,8 @@ class TestNixlHandshake:
 | 
			
		||||
            worker = connector.connector_worker
 | 
			
		||||
 | 
			
		||||
            # Minimal local registration params used by add_remote_agent
 | 
			
		||||
            worker.slot_size_bytes = 4096
 | 
			
		||||
            worker.block_len = worker.slot_size_bytes * worker.block_size
 | 
			
		||||
            worker.slot_size_per_layer = [4096]
 | 
			
		||||
            worker.block_len_per_layer = [4096 * worker.block_size]
 | 
			
		||||
            worker.num_blocks = 1
 | 
			
		||||
            worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
 | 
			
		||||
 | 
			
		||||
@ -498,7 +499,7 @@ class TestNixlHandshake:
 | 
			
		||||
                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
 | 
			
		||||
                kv_caches_base_addr=[0],
 | 
			
		||||
                num_blocks=1,
 | 
			
		||||
                block_len=worker.block_len,
 | 
			
		||||
                block_lens=worker.block_len_per_layer,
 | 
			
		||||
                attn_backend_name=worker.backend_name,
 | 
			
		||||
                kv_cache_layout=mismatched_layout,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
 | 
			
		||||
        "target_attn_1": mock.MagicMock(),
 | 
			
		||||
        "target_attn_2": mock.MagicMock()
 | 
			
		||||
    }
 | 
			
		||||
    target_indx_layers: dict[str, mock.MagicMock] = {}
 | 
			
		||||
    # Draft model has one extra attention layer compared to target model
 | 
			
		||||
    all_attn_layers = {
 | 
			
		||||
        **target_attn_layers, "draft_extra_attn": mock.MagicMock()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    all_indx_layers: dict[str, mock.MagicMock] = {}
 | 
			
		||||
 | 
			
		||||
    # Make mock_get_layers return different values for each call
 | 
			
		||||
    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
 | 
			
		||||
    mock_get_layers.side_effect = [
 | 
			
		||||
        target_attn_layers, target_indx_layers, all_attn_layers,
 | 
			
		||||
        all_indx_layers
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    # Setup mock for pp group to return the appropriate value for world size
 | 
			
		||||
    mock_pp_group = mock.MagicMock()
 | 
			
		||||
@ -658,6 +664,9 @@ def test_propose_tree(spec_token_tree):
 | 
			
		||||
    # Mock runner for attention metadata building.
 | 
			
		||||
    proposer.runner = mock.MagicMock()
 | 
			
		||||
    proposer.runner.attn_groups.append([mock.MagicMock()])
 | 
			
		||||
    proposer.runner.attn_groups[0][0].metadata_builders = [
 | 
			
		||||
        attn_metadata_builder
 | 
			
		||||
    ]
 | 
			
		||||
    proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
 | 
			
		||||
        attn_metadata_builder
 | 
			
		||||
    proposer._get_attention_metadata_builder = mock.MagicMock(
 | 
			
		||||
tests/v1/spec_decode/test_mtp.py (new file, 201 lines)
@ -0,0 +1,201 @@
 | 
			
		||||
# SPDX-License-Identifier: Apache-2.0
 | 
			
		||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 | 
			
		||||
 | 
			
		||||
from unittest import mock
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
 | 
			
		||||
                                      create_common_attn_metadata,
 | 
			
		||||
                                      create_standard_kv_cache_spec,
 | 
			
		||||
                                      get_attention_backend)
 | 
			
		||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
 | 
			
		||||
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
 | 
			
		||||
                         VllmConfig)
 | 
			
		||||
from vllm.config.load import LoadConfig
 | 
			
		||||
from vllm.model_executor.models.llama import LlamaForCausalLM
 | 
			
		||||
from vllm.platforms import current_platform
 | 
			
		||||
from vllm.v1.spec_decode.eagle import EagleProposer
 | 
			
		||||
 | 
			
		||||
mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
 | 
			
		||||
    """Create an MTP proposer with unified model configuration."""
 | 
			
		||||
    model_config = ModelConfig(model=mimo_7b_dir,
 | 
			
		||||
                               runner="generate",
 | 
			
		||||
                               max_model_len=100,
 | 
			
		||||
                               trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
    speculative_config = SpeculativeConfig(
 | 
			
		||||
        target_model_config=model_config,
 | 
			
		||||
        target_parallel_config=ParallelConfig(),
 | 
			
		||||
        model=mimo_7b_dir,
 | 
			
		||||
        method="mtp",
 | 
			
		||||
        num_speculative_tokens=num_speculative_tokens,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    vllm_config = VllmConfig(
 | 
			
		||||
        model_config=model_config,
 | 
			
		||||
        cache_config=CacheConfig(),
 | 
			
		||||
        speculative_config=speculative_config,
 | 
			
		||||
        device_config=DeviceConfig(device=current_platform.device_type),
 | 
			
		||||
        parallel_config=ParallelConfig(),
 | 
			
		||||
        load_config=LoadConfig(),
 | 
			
		||||
        scheduler_config=SchedulerConfig())
 | 
			
		||||
 | 
			
		||||
    return EagleProposer(vllm_config=vllm_config,
 | 
			
		||||
                         device=current_platform.device_type)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
 | 
			
		||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
 | 
			
		||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
 | 
			
		||||
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
 | 
			
		||||
                                mock_get_pp_group):
 | 
			
		||||
    """Test MTP-specific model loading with unified model approach."""
 | 
			
		||||
 | 
			
		||||
    # Setup mocks
 | 
			
		||||
    mock_model = mock.MagicMock()
 | 
			
		||||
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
 | 
			
		||||
    mock_get_model.return_value = mock_model
 | 
			
		||||
 | 
			
		||||
    target_attn_layers = {"target_attn_1": mock.MagicMock()}
 | 
			
		||||
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
 | 
			
		||||
    target_indexer_layers: dict = {}
 | 
			
		||||
    all_indexer_layers: dict = {}
 | 
			
		||||
 | 
			
		||||
    mock_get_layers.side_effect = [
 | 
			
		||||
        target_attn_layers, target_indexer_layers, all_attn_layers,
 | 
			
		||||
        all_indexer_layers
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    mock_pp_group = mock.MagicMock()
 | 
			
		||||
    mock_pp_group.world_size = 1
 | 
			
		||||
    mock_get_pp_group.return_value = mock_pp_group
 | 
			
		||||
 | 
			
		||||
    # Create target model
 | 
			
		||||
    class _TargetModelStub(LlamaForCausalLM):
 | 
			
		||||
        model: mock.MagicMock
 | 
			
		||||
        lm_head: mock.MagicMock
 | 
			
		||||
 | 
			
		||||
    target_model = mock.create_autospec(_TargetModelStub, instance=True)
 | 
			
		||||
    target_model.model = mock.MagicMock()
 | 
			
		||||
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
 | 
			
		||||
    target_model.lm_head = mock.MagicMock()
 | 
			
		||||
 | 
			
		||||
    # Create MTP proposer
 | 
			
		||||
    proposer = _create_mtp_proposer(num_speculative_tokens=4)
 | 
			
		||||
    proposer.load_model(target_model)
 | 
			
		||||
 | 
			
		||||
    # Verify MTP-specific behavior:
 | 
			
		||||
    # Model is loaded
 | 
			
		||||
    mock_get_model.assert_called_once()
 | 
			
		||||
    # MTP shares lm_head with target model
 | 
			
		||||
    assert proposer.model.lm_head == target_model.lm_head
 | 
			
		||||
    # MTP shares embed_tokens with target model
 | 
			
		||||
    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("num_speculative_tokens", [1])
 | 
			
		||||
def test_mtp_propose(num_speculative_tokens, monkeypatch):
 | 
			
		||||
    """Test that MTP's forward method returns hidden states directly"""
 | 
			
		||||
 | 
			
		||||
    device = torch.device(current_platform.device_type)
 | 
			
		||||
    batch_size = 2
 | 
			
		||||
    seq_lens = [5, 3]
 | 
			
		||||
    total_tokens = sum(seq_lens)
 | 
			
		||||
    vocab_size = 100
 | 
			
		||||
 | 
			
		||||
    proposer = _create_mtp_proposer(num_speculative_tokens)
 | 
			
		||||
    hidden_size = proposer.hidden_size
 | 
			
		||||
 | 
			
		||||
    # Mock the MTP model to verify it returns hidden states directly
 | 
			
		||||
    model_mock = mock.MagicMock()
 | 
			
		||||
 | 
			
		||||
    # MTP returns hidden states directly
 | 
			
		||||
    if num_speculative_tokens == 1:
 | 
			
		||||
        model_mock.return_value = torch.zeros(total_tokens,
 | 
			
		||||
                                              hidden_size,
 | 
			
		||||
                                              device=device)
 | 
			
		||||
    else:
 | 
			
		||||
        # Multiple forward passes for multi-token speculation
 | 
			
		||||
        forward_returns = []
 | 
			
		||||
        for i in range(num_speculative_tokens):
 | 
			
		||||
            if i == 0:
 | 
			
		||||
                h_states = torch.zeros(total_tokens,
 | 
			
		||||
                                       hidden_size,
 | 
			
		||||
                                       device=device)
 | 
			
		||||
            else:
 | 
			
		||||
                h_states = torch.zeros(batch_size, hidden_size, device=device)
 | 
			
		||||
            forward_returns.append(h_states)
 | 
			
		||||
        model_mock.side_effect = forward_returns
 | 
			
		||||
 | 
			
		||||
    # Mock compute_logits
 | 
			
		||||
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
 | 
			
		||||
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
 | 
			
		||||
        logits[:, token_offset] = 100.0
 | 
			
		||||
        return logits
 | 
			
		||||
 | 
			
		||||
    if num_speculative_tokens == 1:
 | 
			
		||||
        model_mock.compute_logits.return_value = create_deterministic_logits(
 | 
			
		||||
            batch_size, vocab_size, 42)
 | 
			
		||||
    else:
 | 
			
		||||
        logits_returns = [
 | 
			
		||||
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
 | 
			
		||||
            for i in range(num_speculative_tokens)
 | 
			
		||||
        ]
 | 
			
		||||
        model_mock.compute_logits.side_effect = logits_returns
 | 
			
		||||
 | 
			
		||||
    proposer.model = model_mock
 | 
			
		||||
    proposer.attn_layer_names = ["layer.0"]
 | 
			
		||||
 | 
			
		||||
    # Prepare inputs
 | 
			
		||||
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
 | 
			
		||||
    common_attn_metadata = create_common_attn_metadata(batch_spec,
 | 
			
		||||
                                                       block_size=16,
 | 
			
		||||
                                                       device=device)
 | 
			
		||||
 | 
			
		||||
    target_token_ids = torch.randint(0,
 | 
			
		||||
                                     vocab_size, (total_tokens, ),
 | 
			
		||||
                                     device=device)
 | 
			
		||||
    target_positions = torch.cat([
 | 
			
		||||
        torch.arange(seq_lens[0], device=device),
 | 
			
		||||
        torch.arange(seq_lens[1], device=device)
 | 
			
		||||
    ])
 | 
			
		||||
    target_hidden_states = torch.randn(total_tokens,
 | 
			
		||||
                                       hidden_size,
 | 
			
		||||
                                       device=device)
 | 
			
		||||
    next_token_ids = torch.randint(0,
 | 
			
		||||
                                   vocab_size, (batch_size, ),
 | 
			
		||||
                                   dtype=torch.int32,
 | 
			
		||||
                                   device=device)
 | 
			
		||||
    sampling_metadata = mock.MagicMock()
 | 
			
		||||
 | 
			
		||||
    # Setup attention metadata
 | 
			
		||||
    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
 | 
			
		||||
 | 
			
		||||
    attn_metadata_builder = attn_metadata_builder_cls(
 | 
			
		||||
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
 | 
			
		||||
        layer_names=proposer.attn_layer_names,
 | 
			
		||||
        vllm_config=proposer.vllm_config,
 | 
			
		||||
        device=device,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    proposer.runner = mock.MagicMock()
 | 
			
		||||
    proposer.attn_metadata_builder = attn_metadata_builder
 | 
			
		||||
 | 
			
		||||
    # Run propose
 | 
			
		||||
    result = proposer.propose(target_token_ids=target_token_ids,
 | 
			
		||||
                              target_positions=target_positions,
 | 
			
		||||
                              target_hidden_states=target_hidden_states,
 | 
			
		||||
                              next_token_ids=next_token_ids,
 | 
			
		||||
                              last_token_indices=None,
 | 
			
		||||
                              common_attn_metadata=common_attn_metadata,
 | 
			
		||||
                              sampling_metadata=sampling_metadata)
 | 
			
		||||
 | 
			
		||||
    # Verify the model was called correctly
 | 
			
		||||
    assert model_mock.called
 | 
			
		||||
    # Verify output shape
 | 
			
		||||
    assert result.shape == (batch_size, num_speculative_tokens)
 | 
			
		||||
@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner):
            runner.parallel_config),
        head_size=runner.model_config.get_head_size(),
        dtype=runner.kv_cache_dtype,
        use_mla=False,
    )
    tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
    kv_cache_config = KVCacheConfig(

@ -1678,6 +1678,15 @@ def cp_gather_cache(src_cache: torch.Tensor,
                                           cu_seq_lens, batch_size, seq_starts)


def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
                              slot_mapping: torch.Tensor,
                              quant_block_size: int,
                              kv_cache_dtype: str) -> None:
    torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
                                                     quant_block_size,
                                                     kv_cache_dtype)


def get_device_attribute(attribute: int, device: int) -> int:
    return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)


@ -70,6 +70,7 @@ class AttentionBackend(ABC):
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        raise NotImplementedError


@ -95,6 +95,7 @@ class Attention(nn.Module, AttentionLayerBase):
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sparse: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
@ -155,6 +156,7 @@ class Attention(nn.Module, AttentionLayerBase):
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.use_sparse = use_sparse
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
@ -187,7 +189,8 @@ class Attention(nn.Module, AttentionLayerBase):
                                                 kv_cache_dtype,
                                                 block_size,
                                                 use_mla=use_mla,
                                                 has_sink=self.has_sink)
                                                 has_sink=self.has_sink,
                                                 use_sparse=use_sparse)
        else:
            self.attn_backend = attn_backend


@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import List, Optional
from typing import ClassVar, List, Optional

import torch

@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata, make_local_attention_virtual_batches,
    subclass_attention_backend)
    AttentionCGSupport, CommonAttentionMetadata,
    make_local_attention_virtual_batches, subclass_attention_backend)

from ..layer import Attention

@ -28,6 +28,8 @@ def create_chunked_local_attention_backend(
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
        cudagraph_support: ClassVar[AttentionCGSupport] = \
            AttentionCGSupport.NEVER

        def build(self,
                  common_prefix_len: int,

@ -138,3 +138,208 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    out = cp_group.reduce_scatter(out, dim=1)
    return out


@triton.jit
def _pack_seq_kernel(
        x_ptr,  # [N, D]
        out_ptr,  # [B, Lmax, D]
        lengths_ptr,  # *i32, [B]
        N: tl.constexpr,
        D: tl.constexpr,
        Lmax: tl.constexpr,
        PAD_VALUE: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # Compute start index and sequence length from cumulative lengths
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax

    # compute input row indices for valid (b, t)
    in_row = in_start + off_t
    valid_row = (off_t < seq_len) & t_mask

    # Pointers
    # x_ptr: row-major [N, D]
    x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]

    # out_ptr: row-major [B, Lmax, D]
    out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
                                                   None] * D + off_d[None, :]

    # Initialize with PAD (cast will occur as needed based on out_ptr dtype)
    d_mask = off_d[None, :] < D
    pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
    tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)

    # Load & write only where within seq_len
    x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)


def pack_seq_triton(x: torch.Tensor,
                    lengths: torch.Tensor,
                    pad_value: float = -float('inf'),
                    block_t: int = 64,
                    block_d: int = 64) -> torch.Tensor:
    """
    Pack sequences of different lengths into a batched tensor.

    Args:
        x: [N, ...] - input tensor where N is total number of tokens
        lengths: [B] - sequence lengths for each batch
        pad_value: value to use for padding
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        packed: [B, Lmax, ...] - packed tensor
    """

    # Handle multi-dimensional input by reshaping to (N, -1)
    original_shape = x.shape
    if len(original_shape) > 2:
        N = original_shape[0]
        x_reshaped = x.reshape(N, -1)
        D = x_reshaped.shape[1]
    else:
        N, D = x.shape
        x_reshaped = x

    B = lengths.numel()
    Lmax = int(lengths.max().item())

    # Starts are computed inside the kernel from lengths

    out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)

    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _pack_seq_kernel[grid](x_reshaped,
                           out,
                           lengths.int(),
                           N,
                           D,
                           Lmax,
                           PAD_VALUE=float(pad_value),
                           BLOCK_T=block_t,
                           BLOCK_D=block_d,
                           num_warps=4,
                           num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 2:
        output_shape = (B, Lmax) + original_shape[1:]
        out = out.reshape(output_shape)

    return out


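# Editorial sketch (not part of the change above): a plain-PyTorch reference
# with the same packing semantics as pack_seq_triton for a 2-D input. Useful
# as a slow baseline when unit-testing the Triton kernel; the function name
# is illustrative only.
import torch


def pack_seq_reference(x: torch.Tensor, lengths: torch.Tensor,
                       pad_value: float = -float('inf')) -> torch.Tensor:
    B = lengths.numel()
    Lmax = int(lengths.max().item())
    out = torch.full((B, Lmax, x.shape[1]), pad_value,
                     device=x.device, dtype=x.dtype)
    start = 0
    for b in range(B):
        n = int(lengths[b].item())
        out[b, :n] = x[start:start + n]  # rows of sequence b, in order
        start += n
    return out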
@triton.jit
def _unpack_seq_triton_kernel(
        packed_ptr,  # [B, Lmax, D]
        out_ptr,  # [N, D]
        lengths_ptr,  # *i32, [B]
        B: tl.constexpr,
        Lmax: tl.constexpr,
        D: tl.constexpr,
        BLOCK_T: tl.constexpr,  # timesteps per program
        BLOCK_D: tl.constexpr  # features per program
):
    pid_b = tl.program_id(0)  # batch id
    pid_t = tl.program_id(1)  # block over time dimension
    pid_d = tl.program_id(2)  # block over feature dimension
    off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # [BLOCK_T]
    off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)  # [BLOCK_D]

    # bounds: compute start from cumulative lengths
    in_start = 0
    for i in range(pid_b):
        in_start += tl.load(lengths_ptr + i)
    seq_len = tl.load(lengths_ptr + pid_b)

    # valid time positions for this block
    t_mask = off_t < Lmax
    valid_row = (off_t < seq_len) & t_mask

    # compute output row indices for valid (b, t)
    out_row = in_start + off_t

    # Pointers
    # packed_ptr: row-major [B, Lmax, D]
    packed_row_ptr = packed_ptr + (pid_b * Lmax +
                                   off_t)[:, None] * D + off_d[None, :]

    # out_ptr: row-major [N, D]
    out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]

    # Load from packed tensor and store to output
    d_mask = off_d[None, :] < D
    packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
    tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)


def unpack_seq_triton(packed_tensor: torch.Tensor,
                      lengths: torch.Tensor,
                      block_t: int = 64,
                      block_d: int = 64) -> torch.Tensor:
    """
    Unpack a packed decode query tensor back to the original format.
    Efficient Triton implementation.

    Args:
        packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
        lengths: [B] - sequence lengths for each batch
        block_t: block size for time dimension
        block_d: block size for feature dimension

    Returns:
        unpacked_tensor: [N, ...] where N = sum(lengths)
    """

    # Handle multi-dimensional input by reshaping to (B, Lmax, -1)
    original_shape = packed_tensor.shape
    if len(original_shape) > 3:
        B, Lmax = original_shape[:2]
        packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
        D = packed_reshaped.shape[2]
    else:
        B, Lmax, D = packed_tensor.shape
        packed_reshaped = packed_tensor

    # Calculate total number of elements
    N = int(lengths.sum().item())

    out = torch.empty((N, D),
                      device=packed_tensor.device,
                      dtype=packed_tensor.dtype)

    grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
    _unpack_seq_triton_kernel[grid](packed_reshaped,
                                    out,
                                    lengths.int(),
                                    B,
                                    Lmax,
                                    D,
                                    BLOCK_T=block_t,
                                    BLOCK_D=block_d,
                                    num_warps=4,
                                    num_stages=2)

    # Reshape output back to original dimensions (except first dimension)
    if len(original_shape) > 3:
        output_shape = (N, ) + original_shape[2:]
        out = out.reshape(output_shape)

    return out

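# Editorial usage sketch (not part of the change above): packing and then
# unpacking should round-trip exactly on the valid (non-padded) positions.
# Assumes a CUDA device, since both functions launch Triton kernels.
import torch


def _pack_unpack_roundtrip_demo():
    lengths = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
    x = torch.randn(int(lengths.sum()), 64, device="cuda")
    packed = pack_seq_triton(x, lengths)  # [3, 5, 64], padded with -inf
    unpacked = unpack_seq_triton(packed, lengths)  # [10, 64]
    assert torch.equal(unpacked, x)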
@ -19,6 +19,15 @@ if current_platform.is_cuda():
else:
    _flashmla_C_AVAILABLE = False

if current_platform.is_cuda():
    try:
        import vllm._flashmla_extension_C  # noqa: F401
        _flashmla_extension_C_AVAILABLE = True
    except ImportError:
        _flashmla_extension_C_AVAILABLE = False
else:
    _flashmla_extension_C_AVAILABLE = False


def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
    """
@ -37,24 +46,34 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:


def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
        cache_seqlens: torch.Tensor,
        num_q_tokens_per_head_k: int,
        num_heads_k: int,
        num_heads_q: Optional[int] = None,
        is_fp8_kvcache: bool = False,
        topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.
    - cache_seqlens: (batch_size), dtype torch.int32.
    - num_q_tokens_per_head_k:
            Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
    - num_heads_k: The number of k heads.
    - num_heads_q:
            The number of q heads.
            This argument is optional when sparse attention is not enabled
    - is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
    - topk: If not None, sparse attention will be enabled,
            and only tokens in the `indices` array
            passed to `flash_mla_with_kvcache_sm90` will be attended to.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    Returns:
    - tile_scheduler_metadata:
            (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    - num_splits: (batch_size + 1), dtype torch.int32.
    """
    return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
                                                  num_heads_per_head_k,
                                                  num_heads_k)
    return torch.ops._flashmla_C.get_mla_decoding_metadata(
        cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
        is_fp8_kvcache, topk)


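# Editorial usage sketch (not part of the change above): with MLA the latent
# KV cache is stored with a single k head, so num_q_tokens_per_head_k is just
# the number of query tokens per request times num_heads_q. The head counts
# below are illustrative, not taken from a real model config.
def _get_mla_metadata_example(cache_seqlens, num_heads_q=128,
                              q_tokens_per_req=1, topk=None):
    tile_scheduler_metadata, num_splits = get_mla_metadata(
        cache_seqlens,
        num_q_tokens_per_head_k=q_tokens_per_req * num_heads_q,
        num_heads_k=1,
        num_heads_q=num_heads_q,
        is_fp8_kvcache=False,
        topk=topk,
    )
    return tile_scheduler_metadata, num_splits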
def flash_mla_with_kvcache(
@ -69,45 +88,95 @@ def flash_mla_with_kvcache(
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
                                 torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax.
                       Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.
        descale_q: (batch_size), torch.float32. Descaling factors for Q.
        descale_k: (batch_size), torch.float32. Descaling factors for K.
    - q: (batch_size, seq_len_q, num_heads_q, head_dim).
    - k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
    - block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
    - cache_seqlens: (batch_size), torch.int32.
    - head_dim_v: Head dimension of v.
    - tile_scheduler_metadata:
        (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
        returned by get_mla_metadata.
    - num_splits:
        (batch_size + 1), torch.int32, returned by get_mla_metadata.
    - softmax_scale: float.
        The scale of QK^T before applying softmax.
        Default to 1 / sqrt(head_dim).
    - causal: bool. Whether to apply causal attention mask.
    - descale_q: (batch_size),
        torch.float32. Descaling factors for Q, used for fp8 quantization.
    - descale_k: (batch_size),
        torch.float32. Descaling factors for K, used for fp8 quantization.
    - is_fp8_kvcache: bool.
        Whether the k_cache and v_cache are in fp8 format.
        For the format of FP8 KV cache, please refer to README.md
    - indices: (batch_size, seq_len_q, topk), torch.int32.
        If not None, sparse attention will be enabled,
        and only tokens in the `indices` array will be attended to.
        Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
        For details about how to set up `indices`, please refer to README.md.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    Returns:
    - out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
    - softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1]**(-0.5)
    out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
        q,
        k_cache,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        descale_q,
        descale_k,
    )
    if indices is not None:
        # NOTE (zyongye): sparse attention is still causal in effect,
        # since it only attends to tokens before the current one,
        # but `causal` must not be passed explicitly here.
        assert not causal, \
            "causal must be `false` if sparse attention is enabled."
    assert (descale_q is None) == (
        descale_k is None
    ), "descale_q and descale_k should be both None or both not None"

    # Note(hc): need revisit when we support DCP with decode query_len > 1.
    return out.squeeze(1), softmax_lse.squeeze(-1)
    if indices is None and q.element_size() == 1:
        out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
    else:
        out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
            q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
            causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
            indices)
    return out, softmax_lse


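# Editorial end-to-end sketch (not part of the change above): a dense decode
# call combining get_mla_metadata and flash_mla_with_kvcache. Shapes in the
# comments are illustrative (576 = 512 latent + 64 rope dims is typical for
# DeepSeek-style MLA, but that is an assumption here, not a requirement).
def _flash_mla_decode_example(q, k_cache, block_table, cache_seqlens):
    # q: (batch, 1, num_heads_q, 576); k_cache: (num_blocks, 64, 1, 576)
    tile_md, num_splits = get_mla_metadata(
        cache_seqlens, q.shape[1] * q.shape[2], num_heads_k=1)
    out, lse = flash_mla_with_kvcache(
        q,
        k_cache,
        block_table,
        cache_seqlens,
        head_dim_v=512,
        tile_scheduler_metadata=tile_md,
        num_splits=num_splits,
        causal=True,
    )
    return out, lse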
def flash_mla_sparse_prefill(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse attention prefill kernel

    Args:
    - q: [s_q, h_q, d_qk], bfloat16
    - kv: [s_kv, h_kv, d_qk], bfloat16
    - indices: [s_q, h_kv, topk], int32.
        Invalid indices should be set to -1 or numbers >= s_kv
    - sm_scale: float
    - d_v: The dimension of value vectors. Can only be 512

    Returns:
    - (output, max_logits, lse)
        About the definition of output,
        max_logits and lse, please refer to README.md
    - output: [s_q, h_q, d_v], bfloat16
    - max_logits: [s_q, h_q], float
    - lse: [s_q, h_q], float, 2-based log-sum-exp
    """
    results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
                                                       sm_scale, d_v)
    return results


#

@ -50,6 +50,7 @@ class PagedAttention:
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> Tuple[int, ...]:
        return (2, num_blocks, block_size * num_kv_heads * head_size)


@ -144,6 +144,7 @@ def get_attn_backend(
    block_size: int,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@ -158,6 +159,7 @@ def get_attn_backend(
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
        has_sink=has_sink,
        use_sparse=use_sparse,
    )


@ -170,6 +172,7 @@ def _cached_get_attn_backend(
    use_v1: bool = False,
    use_mla: bool = False,
    has_sink: bool = False,
    use_sparse: bool = False,
) -> type[AttentionBackend]:

    # Check whether a particular choice of backend was
@ -203,7 +206,7 @@ def _cached_get_attn_backend(
    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
        use_mla, has_sink)
        use_mla, has_sink, use_sparse)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")

@ -22,7 +22,8 @@ else:
logger = init_logger(__name__)

BlockSize = Literal[1, 8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2",
                     "fp8_inc"]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]

@ -52,7 +53,11 @@ class CacheConfig:
    cache_dtype: CacheDType = "auto"
    """Data type for kv cache storage. If "auto", will use model data type.
    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
    fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
    Some models (namely DeepSeek-V3.2) default to fp8; set this to bfloat16
    to use bfloat16 instead. This option is invalid for models that do not
    default to fp8.
    """
    is_attention_free: bool = False
    """Whether the model is attention-free. This is primarily set in
    `ModelConfig` and that value should be manually duplicated here."""
@ -171,11 +176,12 @@ class CacheConfig:
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype in get_args(CacheDType):
            logger.info(
                "Using fp8 data type to store kv cache. It reduces the GPU "
                "memory footprint and boosts the performance. "
                "Meanwhile, it may cause accuracy drop without a proper "
                "scaling factor.")
            if self.cache_dtype.startswith("fp8"):
                logger.info(
                    "Using fp8 data type to store kv cache. It reduces the GPU "
                    "memory footprint and boosts the performance. "
                    "Meanwhile, it may cause accuracy drop without a proper "
                    "scaling factor.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")


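# Editorial usage sketch (not part of the change above): the new "bfloat16"
# option lets a model whose checkpoint defaults the KV cache to fp8 (e.g. a
# DeepSeek-V3.2 variant) fall back to a bfloat16 KV cache. The model id below
# is a placeholder.
from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp",  # placeholder model id
          kv_cache_dtype="bfloat16")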
@ -360,6 +360,7 @@ class CompilationConfig:
        "vllm.linear_attention",
        "vllm.plamo2_mamba_mixer",
        "vllm.gdn_attention",
        "vllm.sparse_attn_indexer",
    ]

    def compute_hash(self) -> str:

@ -137,6 +137,9 @@ class ModelConfig:
    """Allowing API requests to read local images or videos from directories
    specified by the server file system. This is a security risk. Should only
    be enabled in trusted environments."""
    allowed_media_domains: Optional[list[str]] = None
    """If set, only media URLs that belong to one of these domains can be
    used for multi-modal inputs."""
    revision: Optional[str] = None
    """The specific model version to use. It can be a branch name, a tag name,
    or a commit id. If unspecified, will use the default version."""
@ -1074,14 +1077,14 @@ class ModelConfig:
        if not hasattr(self.hf_text_config, "model_type"):
            return False
        elif self.hf_text_config.model_type in \
            ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
            ('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp',
              'kimi_k2', 'longcat_flash'):
            return self.hf_text_config.kv_lora_rank is not None
        elif self.hf_text_config.model_type == 'eagle':
            # if the model is an EAGLE module, check for the
            # underlying architecture
            return self.hf_text_config.model.model_type in \
                    ('deepseek_v2', 'deepseek_v3') \
                    ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
                and self.hf_text_config.kv_lora_rank is not None
        return False


@ -279,6 +279,24 @@ class ParallelConfig:
        assert last_exc is not None
        raise last_exc

    # The all_reduce at the end of attention (during o_proj) means that
    # inputs are replicated across each rank of the tensor parallel group.
    # If using expert-parallelism with DeepEP All2All ops, replicated
    # tokens result in useless duplicate computation and communication.
    #
    # In this case, ensure the input to the experts is sequence parallel
    # to avoid the excess work.
    #
    # Not needed for pplx-kernels as it can handle duplicate input tokens.
    @property
    def use_sequence_parallel_moe(self) -> bool:
        return (envs.VLLM_ALL2ALL_BACKEND
                in ("allgather_reducescatter", "naive",
                    "deepep_high_throughput", "deepep_low_latency")
                and self.enable_expert_parallel
                and self.tensor_parallel_size > 1
                and self.data_parallel_size > 1)

    @staticmethod
    def has_unfinished_dp(dp_group: ProcessGroup,
                          has_unfinished: bool) -> bool:

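# Editorial illustration (not part of the change above): the property is True
# only when all four conditions hold. Standalone restatement, not vLLM API.
def _would_use_sequence_parallel_moe(all2all_backend: str,
                                     enable_expert_parallel: bool,
                                     tp_size: int, dp_size: int) -> bool:
    return (all2all_backend in ("allgather_reducescatter", "naive",
                                "deepep_high_throughput",
                                "deepep_low_latency")
            and enable_expert_parallel and tp_size > 1 and dp_size > 1)


# e.g. VLLM_ALL2ALL_BACKEND=deepep_high_throughput with EP enabled,
# TP=2 and DP=2 enables sequence-parallel MoE; with TP=1 it stays off.
assert _would_use_sequence_parallel_moe("deepep_high_throughput", True, 2, 2)
assert not _would_use_sequence_parallel_moe("deepep_high_throughput", True, 1, 2)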
@ -32,14 +32,17 @@ logger = init_logger(__name__)
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                            "mlp_speculator", "draft_model", "deepseek_mtp",
                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
                            "longcat_flash_mtp"]
                            "longcat_flash_mtp", "mtp"]
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
                   "qwen3_next_mtp", "longcat_flash_mtp")


@config
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding."""

    enforce_eager: Optional[bool] = None
    """Override the default enforce_eager from model_config"""
    # General speculative decoding control
    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
    """The number of speculative tokens, if provided. It will default to the
@ -143,7 +146,7 @@ class SpeculativeConfig:

    @staticmethod
    def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
        if hf_config.model_type == "deepseek_v3":
        if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
            hf_config.model_type = "deepseek_mtp"
        if hf_config.model_type == "deepseek_mtp":
            n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
@ -207,11 +210,21 @@ class SpeculativeConfig:
        # can not be detected, it will be considered as the "draft_model" by
        # default.

        if self.method in MTP_MODEL_TYPES:
            logger.warning("method `%s` is deprecated and replaced with mtp.",
                           self.method)
            self.method = "mtp"

        if self.model is None and self.num_speculative_tokens is not None:
            # TODO(Shangming): Refactor mtp configuration logic when supporting
            if (self.target_model_config
                    and self.target_model_config.hf_text_config.model_type
                    in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
            if self.method == "mtp":
                assert (
                    self.target_model_config
                    is not None), "target_model_config must be present for mtp"
                if self.target_model_config.hf_text_config.model_type \
                    == "deepseek_v32":
                    # FIXME(luccafong): cudagraph with v32 MTP is not supported,
                    # remove this when the issue is fixed.
                    self.enforce_eager = True
                # use the draft model from the same model:
                self.model = self.target_model_config.model
                # Align the quantization of draft model for cases such as
@ -281,6 +294,8 @@ class SpeculativeConfig:
                    trust_remote_code,
                    allowed_local_media_path=self.target_model_config.
                    allowed_local_media_path,
                    allowed_media_domains=self.target_model_config.
                    allowed_media_domains,
                    dtype=self.target_model_config.dtype,
                    seed=self.target_model_config.seed,
                    revision=self.revision,
@ -312,31 +327,13 @@ class SpeculativeConfig:
                      "mlp_speculator"):
                    self.method = "mlp_speculator"
                elif (self.draft_model_config.hf_config.model_type
                      in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
                    self.method = "deepseek_mtp"
                      in MTP_MODEL_TYPES):
                    self.method = "mtp"
                    if self.num_speculative_tokens > 1:
                        logger.warning(
                                "All Deepseek MTP models only have " \
                                "one layer. Might need some code changes " \
                                "to support multiple layers."
                            )
                elif (self.draft_model_config.hf_config.model_type ==
                      "ernie_mtp"):
                    self.method = "ernie_mtp"
                    if self.num_speculative_tokens > 1:
                        logger.warning(
                                "All Ernie MTP models only have " \
                                "one layer. Might need some code changes " \
                                "to support multiple layers."
                            )
                elif (self.draft_model_config.hf_config.model_type ==
                      "qwen3_next_mtp"):
                    self.method = "qwen3_next_mtp"
                    if self.num_speculative_tokens > 1:
                        logger.warning(
                                "All Qwen3Next MTP models only have " \
                                "one layer. Might need some code changes " \
                                "to support multiple layers."
                                "Enabling num_speculative_tokens > 1 will " \
                                "run multiple forward passes on the same " \
                                "MTP layer, which may result in a lower " \
                                "acceptance rate."
                            )
                elif (self.draft_model_config.hf_config.model_type
                      in ("longcat_flash_mtp")):
@ -353,7 +350,7 @@ class SpeculativeConfig:
                        "Speculative decoding with draft model is not "
                        "supported yet. Please consider using other "
                        "speculative decoding methods such as ngram, medusa, "
                        "eagle, or deepseek_mtp.")
                        "eagle, or mtp.")

                # Replace hf_config for EAGLE draft_model
                if self.method in ("eagle", "eagle3"):
@ -562,8 +559,7 @@ class SpeculativeConfig:
        return self.num_speculative_tokens

    def use_eagle(self) -> bool:
        return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
                               "qwen3_next_mtp", "longcat_flash_mtp")
        return self.method in ("eagle", "eagle3", "mtp")

    def __repr__(self) -> str:
        method = self.method

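# Editorial usage sketch (not part of the change above): with the
# consolidated method name, a speculative-decoding config can simply say
# "mtp" instead of the deprecated per-model names. Model id is a placeholder.
from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp",  # placeholder model id
          speculative_config={"method": "mtp", "num_speculative_tokens": 1})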
@ -54,6 +54,7 @@ class HTTPConnection:
        stream: bool = False,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
        allow_redirects: bool = True,
    ):
        self._validate_http_url(url)

@ -63,7 +64,8 @@ class HTTPConnection:
        return client.get(url,
                          headers=self._headers(**extra_headers),
                          stream=stream,
                          timeout=timeout)
                          timeout=timeout,
                          allow_redirects=allow_redirects)

    async def get_async_response(
        self,
@ -71,6 +73,7 @@ class HTTPConnection:
        *,
        timeout: Optional[float] = None,
        extra_headers: Optional[Mapping[str, str]] = None,
        allow_redirects: bool = True,
    ):
        self._validate_http_url(url)

@ -79,10 +82,17 @@ class HTTPConnection:

        return client.get(url,
                          headers=self._headers(**extra_headers),
                          timeout=timeout)
                          timeout=timeout,
                          allow_redirects=allow_redirects)

    def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
        with self.get_response(url, timeout=timeout) as r:
    def get_bytes(self,
                  url: str,
                  *,
                  timeout: Optional[float] = None,
                  allow_redirects: bool = True) -> bytes:
        with self.get_response(url,
                               timeout=timeout,
                               allow_redirects=allow_redirects) as r:
            r.raise_for_status()

            return r.content
@ -92,8 +102,10 @@ class HTTPConnection:
        url: str,
        *,
        timeout: Optional[float] = None,
        allow_redirects: bool = True,
    ) -> bytes:
        async with await self.get_async_response(url, timeout=timeout) as r:
        async with await self.get_async_response(
                url, timeout=timeout, allow_redirects=allow_redirects) as r:
            r.raise_for_status()

            return await r.read()

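# Editorial usage sketch (not part of the change above): callers that must
# not follow HTTP redirects (for example when enforcing
# allowed_media_domains) can now opt out per request. global_http_connection
# is assumed to be the module's shared HTTPConnection instance.
from vllm.connections import global_http_connection

image_bytes = global_http_connection.get_bytes(
    "https://example.com/image.png",  # placeholder URL
    timeout=5,
    allow_redirects=False,
)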
@ -6,7 +6,7 @@ import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.distributed import get_dp_group
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
        super().__init__(cpu_group)

    def naive_multicast(self, x: torch.Tensor,
                        cu_tokens_across_dp_cpu: torch.Tensor):
                        cu_tokens_across_sp_cpu: torch.Tensor,
                        is_sequence_parallel: bool) -> torch.Tensor:
        assert (len(x.shape) == 2)
        buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
        buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
                             device=x.device,
                             dtype=x.dtype)

        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
            self.dp_rank - 1]
        end = cu_tokens_across_dp_cpu[self.dp_rank]
        rank = self.rank if is_sequence_parallel else self.dp_rank
        world_size = (self.world_size
                      if is_sequence_parallel else self.dp_world_size)

        start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
        end = cu_tokens_across_sp_cpu[rank]
        buffer[start:end, :].copy_(x)
        for idx in range(self.dp_world_size):
            start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
            end = cu_tokens_across_dp_cpu[idx]
            self.dp_group.broadcast(buffer[start:end, :], idx)
        for idx in range(world_size):
            start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
            end = cu_tokens_across_sp_cpu[idx]
            get_ep_group().broadcast(buffer[start:end, :], idx)

        return buffer

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
        sizes = get_forward_context(
        ).dp_metadata.get_chunk_sizes_across_dp_rank()
        hidden_states, router_logits = get_dp_group().all_gatherv(
            [hidden_states, router_logits],
            dim=0,
            sizes=sizes,
        )
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        dp_metadata = get_forward_context().dp_metadata
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        hidden_states = self.naive_multicast(hidden_states,
                                             cu_tokens_across_sp_cpu,
                                             is_sequence_parallel)
        router_logits = self.naive_multicast(router_logits,
                                             cu_tokens_across_sp_cpu,
                                             is_sequence_parallel)
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        sizes = get_forward_context(
        ).dp_metadata.get_chunk_sizes_across_dp_rank()
        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
                                                       dim=0,
                                                       sizes=sizes)
    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:

        ep_rank = self.rank if is_sequence_parallel else self.dp_rank

        dp_metadata = get_forward_context().dp_metadata
        sp_size = self.tp_group.world_size if is_sequence_parallel else 1
        cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

        start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
        end = cu_tokens_across_sp_cpu[ep_rank]

        all_hidden_states = get_ep_group().all_reduce(hidden_states)
        hidden_states = all_hidden_states[start:end, :]
        return hidden_states

    def destroy(self):
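# Editorial worked example (not part of the change above) of the
# cu_tokens_across_sp layout used by naive_multicast/combine: with two ranks
# contributing 3 and 5 tokens, the cumulative tensor is [3, 8], so rank 0
# owns rows [0, 3) and rank 1 owns rows [3, 8) of the multicast buffer.
import torch

cu_tokens_across_sp_cpu = torch.tensor([3, 8])
rank = 1
start = 0 if rank == 0 else int(cu_tokens_across_sp_cpu[rank - 1])
end = int(cu_tokens_across_sp_cpu[rank])
assert (start, end) == (3, 8)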
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
    def __init__(self, cpu_group):
        super().__init__(cpu_group)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Gather hidden_states and router_logits from all dp ranks.
        """
        sizes = get_forward_context(
        ).dp_metadata.get_chunk_sizes_across_dp_rank()
        hidden_states, router_logits = get_dp_group().all_gatherv(

        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
        assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
        hidden_states, router_logits = dist_group.all_gatherv(
            [hidden_states, router_logits],
            dim=0,
            sizes=sizes,
        )
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        """
        Reduce-scatter hidden_states across all dp ranks.
        """
        sizes = get_forward_context(
        ).dp_metadata.get_chunk_sizes_across_dp_rank()
        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
                                                       dim=0,
                                                       sizes=sizes)

        dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
        hidden_states = dist_group.reduce_scatterv(hidden_states,
                                                   dim=0,
                                                   sizes=sizes)
        return hidden_states

    def destroy(self):
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
            kwargs, pplx.AllToAll.internode
            if self.internode else pplx.AllToAll.intranode)

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
    def get_handle(self, kwargs):
        raise NotImplementedError

    def dispatch(self, hidden_states: torch.Tensor,
                 router_logits: torch.Tensor):
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        raise NotImplementedError

    def destroy(self):

@ -28,6 +28,8 @@ class Cache:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class All2AllManagerBase:
 | 
			
		||||
    rank: int
 | 
			
		||||
    world_size: int
 | 
			
		||||
 | 
			
		||||
    def __init__(self, cpu_group):
 | 
			
		||||
        self.cpu_group = cpu_group
 | 
			
		||||
@ -40,6 +42,7 @@ class All2AllManagerBase:
 | 
			
		||||
        # all2all lives in ep group, which is merged from dp and tp group
 | 
			
		||||
        self.dp_group = get_dp_group()
 | 
			
		||||
        self.tp_group = get_tp_group()
 | 
			
		||||
 | 
			
		||||
        # no self.ep_group since self.ep_group is still in construction
 | 
			
		||||
        # when we create this object
 | 
			
		||||
        self.dp_rank = self.dp_group.rank_in_group
 | 
			
		||||
@ -60,17 +63,21 @@ class All2AllManagerBase:
 | 
			
		||||
        # and reuse it for the same config.
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def dispatch(self,
 | 
			
		||||
                 hidden_states: torch.Tensor,
 | 
			
		||||
                 router_logits: torch.Tensor,
 | 
			
		||||
                 is_sequence_parallel: bool = False):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def set_num_sms(self, num_sms: int):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def max_sms_used(self) -> Optional[int]:
 | 
			
		||||
        return None  # None means it could use the whole GPU
 | 
			
		||||
 | 
			
		||||
    def dispatch(self, hidden_states: torch.Tensor,
 | 
			
		||||
                 router_logits: torch.Tensor):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    def combine(self,
 | 
			
		||||
                hidden_states: torch.Tensor,
 | 
			
		||||
                is_sequence_parallel: bool = False):
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def destroy(self):
 | 
			
		||||
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
 | 
			
		||||
            module.quant_method.init_prepare_finalize(module)
 | 
			
		||||
 | 
			
		||||
    def dispatch(
 | 
			
		||||
            self, hidden_states: torch.Tensor,
 | 
			
		||||
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        router_logits: torch.Tensor,
 | 
			
		||||
        is_sequence_parallel: bool = False
 | 
			
		||||
    ) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Dispatch the hidden states and router logits to the appropriate device.
 | 
			
		||||
        This is a no-op in the base class.
 | 
			
		||||
        """
 | 
			
		||||
        return hidden_states, router_logits
 | 
			
		||||
 | 
			
		||||
    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    def combine(self,
 | 
			
		||||
                hidden_states: torch.Tensor,
 | 
			
		||||
                is_sequence_parallel: bool = False) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Combine the hidden states from the appropriate device.
 | 
			
		||||
        This is a no-op in the base class.
 | 
			
		||||
 | 
			
		||||
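The hunks above thread an is_sequence_parallel flag through every dispatch()/combine() pair in the communicator stack. A minimal standalone sketch of the extended contract, assuming nothing beyond torch (class and method names mirror the diff, but this is not vllm's actual hierarchy):

import torch


class PassthroughCommunicator:
    """No-op communicator honoring the extended dispatch/combine signature."""

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # A real all2all manager would shard tokens across EP ranks here; the
        # flag says whether the input is already split along the sequence
        # dimension instead of being replicated per TP rank.
        return hidden_states, router_logits

    def combine(self,
                hidden_states: torch.Tensor,
                is_sequence_parallel: bool = False) -> torch.Tensor:
        return hidden_states


comm = PassthroughCommunicator()
h, logits = torch.randn(4, 8), torch.randn(4, 2)
h2, _ = comm.dispatch(h, logits, is_sequence_parallel=True)
assert torch.equal(comm.combine(h2, is_sequence_parallel=True), h)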
@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
 | 
			
		||||
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
 | 
			
		||||
            use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
 | 
			
		||||
 | 
			
		||||
        # ep does not use pynccl
 | 
			
		||||
        use_pynccl = "ep" not in unique_name
 | 
			
		||||
 | 
			
		||||
        self.use_pynccl = use_pynccl
 | 
			
		||||
        self.use_custom_allreduce = use_custom_allreduce
 | 
			
		||||
        self.use_torch_symm_mem = use_torch_symm_mem
 | 
			
		||||
 | 
			
		||||
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
 | 
			
		||||
            SymmMemCommunicator)
 | 
			
		||||
 | 
			
		||||
        self.pynccl_comm: Optional[PyNcclCommunicator] = None
 | 
			
		||||
        if use_pynccl and self.world_size > 1:
 | 
			
		||||
        if self.world_size > 1:
 | 
			
		||||
            self.pynccl_comm = PyNcclCommunicator(
 | 
			
		||||
                group=self.cpu_group,
 | 
			
		||||
                device=self.device,
 | 
			
		||||
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
 | 
			
		||||
        return output_list
 | 
			
		||||
 | 
			
		||||
    def dispatch(
 | 
			
		||||
            self, hidden_states: torch.Tensor,
 | 
			
		||||
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        router_logits: torch.Tensor,
 | 
			
		||||
        is_sequence_parallel: bool = False
 | 
			
		||||
    ) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        assert self.all2all_manager is not None
 | 
			
		||||
        hidden_states, router_logits = self.all2all_manager.dispatch(
 | 
			
		||||
            hidden_states, router_logits)
 | 
			
		||||
            hidden_states, router_logits, is_sequence_parallel)
 | 
			
		||||
        return hidden_states, router_logits
 | 
			
		||||
 | 
			
		||||
    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    def combine(self,
 | 
			
		||||
                hidden_states: torch.Tensor,
 | 
			
		||||
                is_sequence_parallel: bool = False) -> torch.Tensor:
 | 
			
		||||
        assert self.all2all_manager is not None
 | 
			
		||||
        hidden_states = self.all2all_manager.combine(hidden_states)
 | 
			
		||||
        hidden_states = self.all2all_manager.combine(hidden_states,
 | 
			
		||||
                                                     is_sequence_parallel)
 | 
			
		||||
        return hidden_states
 | 
			
		||||
 | 
			
		||||
@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
 | 
			
		||||
        dist.broadcast(input_, src=src, group=self.device_group)
 | 
			
		||||
 | 
			
		||||
    def dispatch(
 | 
			
		||||
            self, hidden_states: torch.Tensor,
 | 
			
		||||
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        router_logits: torch.Tensor,
 | 
			
		||||
        is_sequence_parallel: bool = False
 | 
			
		||||
    ) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        assert self.all2all_manager is not None
 | 
			
		||||
        hidden_states, router_logits = self.all2all_manager.dispatch(
 | 
			
		||||
            hidden_states, router_logits)
 | 
			
		||||
            hidden_states, router_logits, is_sequence_parallel)
 | 
			
		||||
        return hidden_states, router_logits
 | 
			
		||||
 | 
			
		||||
    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    def combine(self,
 | 
			
		||||
                hidden_states: torch.Tensor,
 | 
			
		||||
                is_sequence_parallel: bool = False) -> torch.Tensor:
 | 
			
		||||
        assert self.all2all_manager is not None
 | 
			
		||||
        hidden_states = self.all2all_manager.combine(hidden_states)
 | 
			
		||||
        hidden_states = self.all2all_manager.combine(hidden_states,
 | 
			
		||||
                                                     is_sequence_parallel)
 | 
			
		||||
        return hidden_states
 | 
			
		||||
 | 
			
		||||
@ -84,7 +84,7 @@ class NixlAgentMetadata(
 | 
			
		||||
    agent_metadata: bytes
 | 
			
		||||
    kv_caches_base_addr: list[int]
 | 
			
		||||
    num_blocks: int
 | 
			
		||||
    block_len: int
 | 
			
		||||
    block_lens: list[int]
 | 
			
		||||
    attn_backend_name: str
 | 
			
		||||
    kv_cache_layout: str
 | 
			
		||||
 | 
			
		||||
@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
 | 
			
		||||
        self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
 | 
			
		||||
        self.reqs_to_save: dict[ReqId, ReqMeta] = {}
 | 
			
		||||
        self.reqs_to_send: dict[ReqId, float] = {}
 | 
			
		||||
        self.reqs_in_batch: set[ReqId] = set()
 | 
			
		||||
 | 
			
		||||
    def add_new_req(
 | 
			
		||||
        self,
 | 
			
		||||
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
 | 
			
		||||
        self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
 | 
			
		||||
        # Reqs to send and their expiration time
 | 
			
		||||
        self._reqs_need_send: dict[ReqId, float] = {}
 | 
			
		||||
        self._reqs_in_batch: set[ReqId] = set()
 | 
			
		||||
 | 
			
		||||
    def get_num_new_matched_tokens(
 | 
			
		||||
            self, request: "Request",
 | 
			
		||||
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
 | 
			
		||||
 | 
			
		||||
        if not params:
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        if params.get("do_remote_decode"):
 | 
			
		||||
            self._reqs_in_batch.add(request.request_id)
 | 
			
		||||
        if self.use_host_buffer and params.get("do_remote_decode"):
 | 
			
		||||
            # NOTE: when accelerator is not directly supported by Nixl,
 | 
			
		||||
            # prefilled blocks need to be saved to host memory before transfer.
 | 
			
		||||
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
 | 
			
		||||
                request_id=req_id,
 | 
			
		||||
                local_block_ids=block_ids,
 | 
			
		||||
                kv_transfer_params=req.kv_transfer_params,
 | 
			
		||||
                load_remote_cache=True,
 | 
			
		||||
                save_to_host=False,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        for req_id, (req, block_ids) in self._reqs_need_save.items():
 | 
			
		||||
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        meta.reqs_to_send = self._reqs_need_send
 | 
			
		||||
        meta.reqs_in_batch = self._reqs_in_batch
 | 
			
		||||
 | 
			
		||||
        # Clear the list once workers start the transfers
 | 
			
		||||
        self._reqs_need_recv.clear()
 | 
			
		||||
        self._reqs_need_save.clear()
 | 
			
		||||
        self._reqs_in_batch = set()
 | 
			
		||||
        self._reqs_need_send = {}
 | 
			
		||||
 | 
			
		||||
        return meta
 | 
			
		||||
@ -465,8 +474,11 @@ class NixlConnectorWorker:
 | 
			
		||||
                "backends", ["UCX"])
 | 
			
		||||
        # Agent.
 | 
			
		||||
        non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
 | 
			
		||||
        config = nixl_agent_config(backends=self.nixl_backends) if len(
 | 
			
		||||
            non_ucx_backends) > 0 and nixl_agent_config is not None else None
 | 
			
		||||
        if nixl_agent_config is None:
 | 
			
		||||
            config = None
 | 
			
		||||
        else:
 | 
			
		||||
            config = nixl_agent_config(backends=self.nixl_backends) if len(
 | 
			
		||||
                non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
 | 
			
		||||
 | 
			
		||||
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
 | 
			
		||||
        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
 | 
			
		||||
@ -546,6 +558,8 @@ class NixlConnectorWorker:
 | 
			
		||||
        self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
 | 
			
		||||
        # Track the expiration time of requests that are waiting to be sent.
 | 
			
		||||
        self._reqs_to_send: dict[ReqId, float] = {}
 | 
			
		||||
        # Set of requests that have been part of a batch, regardless of status.
 | 
			
		||||
        self._reqs_to_process: set[ReqId] = set()
 | 
			
		||||
 | 
			
		||||
        # Background thread for handling new handshake requests.
 | 
			
		||||
        self._nixl_handshake_listener_t: Optional[threading.Thread] = None
 | 
			
		||||
@ -752,6 +766,9 @@ class NixlConnectorWorker:
 | 
			
		||||
        split_k_and_v = not (self.use_mla or self._use_pallas
 | 
			
		||||
                             or self._use_flashinfer)
 | 
			
		||||
        tensor_size_bytes = None
 | 
			
		||||
        # Enable different block lengths for different layers when MLA is used.
 | 
			
		||||
        self.block_len_per_layer = list[int]()
 | 
			
		||||
        self.slot_size_per_layer = list[int]()  # HD bytes in kv terms
 | 
			
		||||
        for layer_name, cache_or_caches in xfer_buffers.items():
 | 
			
		||||
            cache_list = cache_or_caches if split_k_and_v else [
 | 
			
		||||
                cache_or_caches
 | 
			
		||||
@ -769,10 +786,25 @@ class NixlConnectorWorker:
 | 
			
		||||
                    tensor_size_bytes = curr_tensor_size_bytes
 | 
			
		||||
                    self.num_blocks = cache.shape[0]
 | 
			
		||||
 | 
			
		||||
                assert tensor_size_bytes == curr_tensor_size_bytes, \
 | 
			
		||||
                    "All kv cache tensors must have the same size"
 | 
			
		||||
                assert cache.shape[0] == self.num_blocks, \
 | 
			
		||||
                    "All kv cache tensors must have the same number of blocks"
 | 
			
		||||
 | 
			
		||||
                self.block_len_per_layer.append(curr_tensor_size_bytes //
 | 
			
		||||
                                                self.num_blocks)
 | 
			
		||||
                self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
 | 
			
		||||
                                                self.block_size)
 | 
			
		||||
 | 
			
		||||
                if not self.use_mla:
 | 
			
		||||
                    # Different kv cache shape is not supported by HeteroTP
 | 
			
		||||
                    assert tensor_size_bytes == curr_tensor_size_bytes, \
 | 
			
		||||
                        "All kv cache tensors must have the same size"
 | 
			
		||||
                caches_data.append(
 | 
			
		||||
                    (base_addr, tensor_size_bytes, self.tp_rank, ""))
 | 
			
		||||
                    (base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
 | 
			
		||||
 | 
			
		||||
        logger.debug("Different block lengths collected: %s",
 | 
			
		||||
                     set(self.block_len_per_layer))
 | 
			
		||||
        assert len(self.block_len_per_layer) == len(seen_base_addresses)
 | 
			
		||||
        assert self.num_blocks != 0
 | 
			
		||||
 | 
			
		||||
        self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
 | 
			
		||||
        self.num_regions = len(caches_data)
 | 
			
		||||
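A self-contained sketch of the per-layer bookkeeping added above, with made-up cache shapes; only the tensor byte counts matter, which is why MLA layers with different widths now get their own entry in block_len_per_layer:

import torch

block_size = 16      # tokens per KV block (assumed)
num_blocks = 8
num_kv_heads, head_dim = 4, 64

# One dummy cache per layer; the second layer is narrower on purpose.
kv_caches = [
    torch.zeros(num_blocks, block_size, num_kv_heads, head_dim,
                dtype=torch.float16),
    torch.zeros(num_blocks, block_size, num_kv_heads, head_dim // 2,
                dtype=torch.float16),
]

block_len_per_layer, slot_size_per_layer = [], []
for cache in kv_caches:
    tensor_size_bytes = cache.numel() * cache.element_size()
    block_len_per_layer.append(tensor_size_bytes // num_blocks)
    slot_size_per_layer.append(block_len_per_layer[-1] // block_size)

print(block_len_per_layer)   # [8192, 4096] bytes per block, per layer
print(slot_size_per_layer)   # [512, 256] bytes per token slot, per layer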
@ -785,16 +817,12 @@ class NixlConnectorWorker:
 | 
			
		||||
        logger.debug("Done registering descs")
 | 
			
		||||
        self._registered_descs.append(descs)
 | 
			
		||||
 | 
			
		||||
        assert tensor_size_bytes is not None
 | 
			
		||||
        assert self.num_blocks != 0
 | 
			
		||||
        assert tensor_size_bytes % self.num_blocks == 0
 | 
			
		||||
        self.block_len = tensor_size_bytes // self.num_blocks
 | 
			
		||||
        self.slot_size_bytes = self.block_len // self.block_size
 | 
			
		||||
        self.device_kv_caches = kv_caches
 | 
			
		||||
        self.dst_num_blocks[self.engine_id] = self.num_blocks
 | 
			
		||||
        if self._use_flashinfer:
 | 
			
		||||
            assert self.slot_size_bytes % 2 == 0
 | 
			
		||||
            self.slot_size_bytes /= 2
 | 
			
		||||
            for i in range(len(self.slot_size_per_layer)):
 | 
			
		||||
                assert self.slot_size_per_layer[i] % 2 == 0
 | 
			
		||||
                self.slot_size_per_layer[i] //= 2
 | 
			
		||||
 | 
			
		||||
            # NOTE (NickLucche) When FlashInfer is used, memory is registered
 | 
			
		||||
            # with joint KV for each block. This minimizes the overhead in
 | 
			
		||||
@ -804,17 +832,17 @@ class NixlConnectorWorker:
 | 
			
		||||
            # of 'virtual' regions here and halve `block_len` below.
 | 
			
		||||
            self.num_regions *= 2
 | 
			
		||||
 | 
			
		||||
        kv_block_len = self.get_backend_aware_kv_block_len()
 | 
			
		||||
        # Register local/src descr for NIXL xfer.
 | 
			
		||||
        blocks_data = []
 | 
			
		||||
        for base_addr in seen_base_addresses:
 | 
			
		||||
        for i, base_addr in enumerate(seen_base_addresses):
 | 
			
		||||
            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
 | 
			
		||||
            # NOTE With heter-TP, more blocks are prepared than what are
 | 
			
		||||
            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
 | 
			
		||||
            # could create fewer, but then _get_block_descs_ids needs to
 | 
			
		||||
            # select agent_meta.num_blocks instead of self.num_blocks for
 | 
			
		||||
            # local descr, and that makes handling regular flow less clean.
 | 
			
		||||
            for block_id in range(self.num_blocks):
 | 
			
		||||
                block_offset = block_id * self.block_len
 | 
			
		||||
                block_offset = block_id * self.block_len_per_layer[i]
 | 
			
		||||
                addr = base_addr + block_offset
 | 
			
		||||
                # (addr, len, device id)
 | 
			
		||||
                blocks_data.append((addr, kv_block_len, self.tp_rank))
 | 
			
		||||
@ -824,7 +852,7 @@ class NixlConnectorWorker:
 | 
			
		||||
                # descs ordering. This is needed for selecting contiguous heads
 | 
			
		||||
                # when split across TP ranks.
 | 
			
		||||
                for block_id in range(self.num_blocks):
 | 
			
		||||
                    block_offset = block_id * self.block_len
 | 
			
		||||
                    block_offset = block_id * self.block_len_per_layer[i]
 | 
			
		||||
                    addr = base_addr + block_offset
 | 
			
		||||
                    # Register addresses for V cache (K registered first).
 | 
			
		||||
                    v_addr = addr + kv_block_len
 | 
			
		||||
@ -864,7 +892,7 @@ class NixlConnectorWorker:
 | 
			
		||||
            agent_metadata=self.nixl_wrapper.get_agent_metadata(),
 | 
			
		||||
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
 | 
			
		||||
            num_blocks=self.num_blocks,
 | 
			
		||||
            block_len=self.block_len,
 | 
			
		||||
            block_lens=self.block_len_per_layer,
 | 
			
		||||
            attn_backend_name=self.backend_name,
 | 
			
		||||
            kv_cache_layout=self.kv_cache_layout)
 | 
			
		||||
        ready_event = threading.Event()
 | 
			
		||||
@ -889,7 +917,7 @@ class NixlConnectorWorker:
 | 
			
		||||
        The latter, assuming D.world_size > P.world_size, requires that two or 
 | 
			
		||||
        more local TP workers share the xfer from a single TP worker.
 | 
			
		||||
 | 
			
		||||
        Here's an example:
 | 
			
		||||
        Here's an example (non-MLA case):
 | 
			
		||||
 | 
			
		||||
        rank_offset     p_remote_tp_rank
 | 
			
		||||
        (kv split no)    
 | 
			
		||||
@ -945,14 +973,20 @@ class NixlConnectorWorker:
 | 
			
		||||
        total_num_kv_heads = self.model_config.get_total_num_kv_heads()
 | 
			
		||||
        is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
 | 
			
		||||
 | 
			
		||||
        remote_block_len = nixl_agent_meta.block_lens[0]
 | 
			
		||||
        if self.use_mla or is_kv_replicated:
 | 
			
		||||
            # With MLA the only difference is in the number of blocks.
 | 
			
		||||
            remote_block_size = nixl_agent_meta.block_len // (
 | 
			
		||||
                self.slot_size_bytes)
 | 
			
		||||
            assert self.block_len == nixl_agent_meta.block_len
 | 
			
		||||
            # With replicated KV cache, only the number of blocks can differ.
 | 
			
		||||
            assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
 | 
			
		||||
                "KV cache sizes must match between P and D when replicated"
 | 
			
		||||
            remote_block_size = remote_block_len // (
 | 
			
		||||
                self.slot_size_per_layer[0])
 | 
			
		||||
        else:
 | 
			
		||||
            remote_block_size = nixl_agent_meta.block_len // (
 | 
			
		||||
                self.slot_size_bytes * tp_ratio)
 | 
			
		||||
            # When MLA is not used, this is a list of the same block length
 | 
			
		||||
            for block_len in nixl_agent_meta.block_lens:
 | 
			
		||||
                assert block_len == remote_block_len, \
 | 
			
		||||
                    "All remote layers must have the same block size"
 | 
			
		||||
            remote_block_size = remote_block_len // (
 | 
			
		||||
                self.slot_size_per_layer[0] * tp_ratio)
 | 
			
		||||
            if self._use_flashinfer:
 | 
			
		||||
                # With flashinfer, KV are sent in the same message.
 | 
			
		||||
                remote_block_size //= 2
 | 
			
		||||
@ -963,14 +997,14 @@ class NixlConnectorWorker:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        "Heterogeneous TP is not supported on XPU")
 | 
			
		||||
 | 
			
		||||
            assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
 | 
			
		||||
            assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
 | 
			
		||||
                "Remote P worker KV layer cache must be of shape [2, N, "
 | 
			
		||||
                "local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        assert self.block_size == remote_block_size, (
 | 
			
		||||
            "Remote P worker with different block size is not supported "
 | 
			
		||||
            f"{self.block_size=} {remote_block_size=}")
 | 
			
		||||
            "Remote P worker with different page/block size is not supported "
 | 
			
		||||
            f"{self.block_size=}, {remote_block_size=}")
 | 
			
		||||
 | 
			
		||||
        # Create dst descs and xfer side handles. TP workers have same #blocks.
 | 
			
		||||
        if engine_id in self.dst_num_blocks:
 | 
			
		||||
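Illustrative arithmetic for the heterogeneous-TP checks above; all concrete sizes are invented, and tp_ratio stands for the factor by which the remote P worker's KV shard is wider than the local D shard:

local_block_len = 8192                        # bytes per block per layer on D
block_size = 16                               # tokens per block
slot_size = local_block_len // block_size     # 512 bytes per token per K/V
tp_ratio = 2

remote_block_len = local_block_len * tp_ratio           # what the assert expects
remote_block_size = remote_block_len // (slot_size * tp_ratio)
assert remote_block_size == block_size                   # page sizes must match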
@ -985,13 +1019,16 @@ class NixlConnectorWorker:
 | 
			
		||||
        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
 | 
			
		||||
        self.kv_caches_base_addr[
 | 
			
		||||
            engine_id] = nixl_agent_meta.kv_caches_base_addr
 | 
			
		||||
        kv_block_len = self.get_backend_aware_kv_block_len()
 | 
			
		||||
        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
 | 
			
		||||
            if not (self.use_mla or is_kv_replicated) else 0
 | 
			
		||||
 | 
			
		||||
        assert len(nixl_agent_meta.kv_caches_base_addr) == len(
 | 
			
		||||
            self.block_len_per_layer)
 | 
			
		||||
        # Register all remote blocks, but only the corresponding kv heads.
 | 
			
		||||
        for base_addr in nixl_agent_meta.kv_caches_base_addr:
 | 
			
		||||
        for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
 | 
			
		||||
            kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
 | 
			
		||||
            rank_offset = self.tp_rank % tp_ratio * kv_block_len \
 | 
			
		||||
                if not (self.use_mla or is_kv_replicated) else 0
 | 
			
		||||
            for block_id in range(nixl_agent_meta.num_blocks):
 | 
			
		||||
                block_offset = block_id * nixl_agent_meta.block_len
 | 
			
		||||
                block_offset = block_id * nixl_agent_meta.block_lens[i]
 | 
			
		||||
                # For each block, grab the heads chunk belonging to rank_i
 | 
			
		||||
                # of size remote_nheads // tp_ratio, which correspond to
 | 
			
		||||
                # self.block_len == remote_block_len//tp_ratio bytes.
 | 
			
		||||
@ -1002,9 +1039,9 @@ class NixlConnectorWorker:
 | 
			
		||||
            if self._use_flashinfer:
 | 
			
		||||
                # With FlashInfer index V separately to allow head splitting.
 | 
			
		||||
                for block_id in range(nixl_agent_meta.num_blocks):
 | 
			
		||||
                    block_offset = block_id * nixl_agent_meta.block_len
 | 
			
		||||
                    block_offset = block_id * nixl_agent_meta.block_lens[i]
 | 
			
		||||
                    addr = base_addr + block_offset + rank_offset
 | 
			
		||||
                    v_addr = addr + nixl_agent_meta.block_len // 2
 | 
			
		||||
                    v_addr = addr + nixl_agent_meta.block_lens[i] // 2
 | 
			
		||||
                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
 | 
			
		||||
 | 
			
		||||
        logger.debug(
 | 
			
		||||
@ -1082,6 +1119,7 @@ class NixlConnectorWorker:
 | 
			
		||||
                "Releasing expired KV blocks for request %s which were "
 | 
			
		||||
                "retrieved by %d decode worker(s) within %d seconds.", req_id,
 | 
			
		||||
                count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
 | 
			
		||||
            self._reqs_to_process.remove(req_id)
 | 
			
		||||
            del self._reqs_to_send[req_id]
 | 
			
		||||
            done_sending.add(req_id)
 | 
			
		||||
 | 
			
		||||
@ -1097,7 +1135,8 @@ class NixlConnectorWorker:
 | 
			
		||||
        for notifs in self.nixl_wrapper.get_new_notifs().values():
 | 
			
		||||
            for notif in notifs:
 | 
			
		||||
                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
 | 
			
		||||
                if req_id not in self._reqs_to_send:
 | 
			
		||||
                if (req_id not in self._reqs_to_send
 | 
			
		||||
                        and req_id not in self._reqs_to_process):
 | 
			
		||||
                    logger.error(
 | 
			
		||||
                        "Potentially invalid KV blocks for "
 | 
			
		||||
                        "unrecognized request %s were retrieved by "
 | 
			
		||||
@ -1110,7 +1149,8 @@ class NixlConnectorWorker:
 | 
			
		||||
                        tp_ratio):
 | 
			
		||||
                    notified_req_ids.add(req_id)
 | 
			
		||||
                    del self.consumer_notification_counts_by_req[req_id]
 | 
			
		||||
                    del self._reqs_to_send[req_id]
 | 
			
		||||
                    self._reqs_to_process.remove(req_id)
 | 
			
		||||
                    self._reqs_to_send.pop(req_id, None)
 | 
			
		||||
        return notified_req_ids
 | 
			
		||||
 | 
			
		||||
    def _pop_done_transfers(
 | 
			
		||||
@ -1171,8 +1211,19 @@ class NixlConnectorWorker:
 | 
			
		||||
        while not self._ready_requests.empty():
 | 
			
		||||
            self._read_blocks_for_req(*self._ready_requests.get_nowait())
 | 
			
		||||
 | 
			
		||||
        # Keep around the requests that have been part of a batch. This is
 | 
			
		||||
        # needed because async scheduling pushes the misalignment between the
 | 
			
		||||
        # moment in which requests expiration is set (P side) and the moment in
 | 
			
		||||
        # which blocks are read from D. As P can now more easily lag behind D
 | 
			
		||||
        # while processing the next batch, we make sure to only set an
 | 
			
		||||
        # expiration for requests that have not been read from D yet.
 | 
			
		||||
        for req_id in metadata.reqs_in_batch:
 | 
			
		||||
            self._reqs_to_process.add(req_id)
 | 
			
		||||
 | 
			
		||||
        # Add to requests that are waiting to be read and track expiration.
 | 
			
		||||
        self._reqs_to_send.update(metadata.reqs_to_send)
 | 
			
		||||
        for req_id, expiration_time in metadata.reqs_to_send.items():
 | 
			
		||||
            if req_id in self._reqs_to_process:
 | 
			
		||||
                self._reqs_to_send[req_id] = expiration_time
 | 
			
		||||
 | 
			
		||||
    def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
 | 
			
		||||
        logger.debug(
 | 
			
		||||
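A toy sketch of the bookkeeping described in the comment above (names mirror the diff, but this is not the connector itself): expirations are only armed for requests that D has not already read.

reqs_to_process: set[str] = set()      # seen in a batch, not yet read by D
reqs_to_send: dict[str, float] = {}    # pending requests -> expiration time

def on_connector_metadata(reqs_in_batch, reqs_with_expiration):
    reqs_to_process.update(reqs_in_batch)
    for req_id, expiration in reqs_with_expiration.items():
        if req_id in reqs_to_process:   # skip requests D already consumed
            reqs_to_send[req_id] = expiration

def on_read_done(req_id):
    reqs_to_process.discard(req_id)
    reqs_to_send.pop(req_id, None)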
@ -1317,7 +1368,7 @@ class NixlConnectorWorker:
 | 
			
		||||
        descs_ids = region_ids * num_blocks + block_ids
 | 
			
		||||
        return descs_ids.flatten()
 | 
			
		||||
 | 
			
		||||
    def get_backend_aware_kv_block_len(self):
 | 
			
		||||
    def get_backend_aware_kv_block_len(self, layer_idx: int):
 | 
			
		||||
        """
 | 
			
		||||
        Get the block length for one K/V element (K and V have the same size).
 | 
			
		||||
 | 
			
		||||
@ -1328,9 +1379,9 @@ class NixlConnectorWorker:
 | 
			
		||||
        """
 | 
			
		||||
        if self._use_flashinfer:
 | 
			
		||||
            # For indexing only half (either just the K or V part).
 | 
			
		||||
            block_len = self.block_len // 2
 | 
			
		||||
            block_len = self.block_len_per_layer[layer_idx] // 2
 | 
			
		||||
        else:
 | 
			
		||||
            block_len = self.block_len
 | 
			
		||||
            block_len = self.block_len_per_layer[layer_idx]
 | 
			
		||||
        return block_len
 | 
			
		||||
 | 
			
		||||
    def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
 | 
			
		||||
 | 
			
		||||
@ -871,17 +871,24 @@ class GroupCoordinator:
 | 
			
		||||
                model)
 | 
			
		||||
 | 
			
		||||
    def dispatch(
 | 
			
		||||
            self, hidden_states: torch.Tensor,
 | 
			
		||||
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        self,
 | 
			
		||||
        hidden_states: torch.Tensor,
 | 
			
		||||
        router_logits: torch.Tensor,
 | 
			
		||||
        is_sequence_parallel: bool = False
 | 
			
		||||
    ) -> tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
        if self.device_communicator is not None:
 | 
			
		||||
            return self.device_communicator.dispatch(hidden_states,
 | 
			
		||||
                                                     router_logits)
 | 
			
		||||
                                                     router_logits,
 | 
			
		||||
                                                     is_sequence_parallel)
 | 
			
		||||
        else:
 | 
			
		||||
            return hidden_states, router_logits
 | 
			
		||||
 | 
			
		||||
    def combine(self, hidden_states) -> torch.Tensor:
 | 
			
		||||
    def combine(self,
 | 
			
		||||
                hidden_states,
 | 
			
		||||
                is_sequence_parallel: bool = False) -> torch.Tensor:
 | 
			
		||||
        if self.device_communicator is not None:
 | 
			
		||||
            return self.device_communicator.combine(hidden_states)
 | 
			
		||||
            return self.device_communicator.combine(hidden_states,
 | 
			
		||||
                                                    is_sequence_parallel)
 | 
			
		||||
        else:
 | 
			
		||||
            return hidden_states
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -297,6 +297,8 @@ class EngineArgs:
 | 
			
		||||
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
 | 
			
		||||
    trust_remote_code: bool = ModelConfig.trust_remote_code
 | 
			
		||||
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
 | 
			
		||||
    allowed_media_domains: Optional[
 | 
			
		||||
        list[str]] = ModelConfig.allowed_media_domains
 | 
			
		||||
    download_dir: Optional[str] = LoadConfig.download_dir
 | 
			
		||||
    safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
 | 
			
		||||
    load_format: Union[str, LoadFormats] = LoadConfig.load_format
 | 
			
		||||
@ -531,6 +533,8 @@ class EngineArgs:
 | 
			
		||||
                                 **model_kwargs["hf_config_path"])
 | 
			
		||||
        model_group.add_argument("--allowed-local-media-path",
 | 
			
		||||
                                 **model_kwargs["allowed_local_media_path"])
 | 
			
		||||
        model_group.add_argument("--allowed-media-domains",
 | 
			
		||||
                                 **model_kwargs["allowed_media_domains"])
 | 
			
		||||
        model_group.add_argument("--revision", **model_kwargs["revision"])
 | 
			
		||||
        model_group.add_argument("--code-revision",
 | 
			
		||||
                                 **model_kwargs["code_revision"])
 | 
			
		||||
@ -997,6 +1001,7 @@ class EngineArgs:
 | 
			
		||||
            tokenizer_mode=self.tokenizer_mode,
 | 
			
		||||
            trust_remote_code=self.trust_remote_code,
 | 
			
		||||
            allowed_local_media_path=self.allowed_local_media_path,
 | 
			
		||||
            allowed_media_domains=self.allowed_media_domains,
 | 
			
		||||
            dtype=self.dtype,
 | 
			
		||||
            seed=self.seed,
 | 
			
		||||
            revision=self.revision,
 | 
			
		||||
@ -1481,7 +1486,7 @@ class EngineArgs:
 | 
			
		||||
                raise NotImplementedError(
 | 
			
		||||
                    "Draft model speculative decoding is not supported yet. "
 | 
			
		||||
                    "Please consider using other speculative decoding methods "
 | 
			
		||||
                    "such as ngram, medusa, eagle, or deepseek_mtp.")
 | 
			
		||||
                    "such as ngram, medusa, eagle, or mtp.")
 | 
			
		||||
 | 
			
		||||
        V1_BACKENDS = [
 | 
			
		||||
            "FLASH_ATTN",
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,12 @@ from pathlib import Path
 | 
			
		||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
 | 
			
		||||
                    cast)
 | 
			
		||||
 | 
			
		||||
import jinja2
 | 
			
		||||
import jinja2.ext
 | 
			
		||||
import jinja2.meta
 | 
			
		||||
import jinja2.nodes
 | 
			
		||||
import jinja2.parser
 | 
			
		||||
import jinja2.sandbox
 | 
			
		||||
import transformers.utils.chat_template_utils as hf_chat_utils
 | 
			
		||||
# yapf conflicts with isort for this block
 | 
			
		||||
# yapf: disable
 | 
			
		||||
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
 | 
			
		||||
# yapf: enable
 | 
			
		||||
from vllm.transformers_utils.processor import cached_get_processor
 | 
			
		||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 | 
			
		||||
from vllm.utils import random_uuid
 | 
			
		||||
from vllm.utils import random_uuid, supports_kw
 | 
			
		||||
 | 
			
		||||
logger = init_logger(__name__)
 | 
			
		||||
 | 
			
		||||
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 | 
			
		||||
    def allowed_local_media_path(self):
 | 
			
		||||
        return self._model_config.allowed_local_media_path
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def allowed_media_domains(self):
 | 
			
		||||
        return self._model_config.allowed_media_domains
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def mm_registry(self):
 | 
			
		||||
        return MULTIMODAL_REGISTRY
 | 
			
		||||
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
 | 
			
		||||
        self._connector = MediaConnector(
 | 
			
		||||
            media_io_kwargs=media_io_kwargs,
 | 
			
		||||
            allowed_local_media_path=tracker.allowed_local_media_path,
 | 
			
		||||
            allowed_media_domains=tracker.allowed_media_domains,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def parse_image(
 | 
			
		||||
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
 | 
			
		||||
        self._connector = MediaConnector(
 | 
			
		||||
            media_io_kwargs=media_io_kwargs,
 | 
			
		||||
            allowed_local_media_path=tracker.allowed_local_media_path,
 | 
			
		||||
            allowed_media_domains=tracker.allowed_media_domains,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def parse_image(
 | 
			
		||||
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
 | 
			
		||||
    return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
 | 
			
		||||
# only preserve the parse function used to resolve chat template kwargs
 | 
			
		||||
class AssistantTracker(jinja2.ext.Extension):
 | 
			
		||||
    tags = {"generation"}
 | 
			
		||||
 | 
			
		||||
    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
 | 
			
		||||
        lineno = next(parser.stream).lineno
 | 
			
		||||
        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
 | 
			
		||||
        call = self.call_method("_generation_support")
 | 
			
		||||
        call_block = jinja2.nodes.CallBlock(call, [], [], body)
 | 
			
		||||
        return call_block.set_lineno(lineno)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def resolve_chat_template_kwargs(
 | 
			
		||||
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
 | 
			
		||||
    chat_template: str,
 | 
			
		||||
    chat_template_kwargs: dict[str, Any],
 | 
			
		||||
) -> dict[str, Any]:
 | 
			
		||||
    fn_kw = {
 | 
			
		||||
        k for k in chat_template_kwargs
 | 
			
		||||
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
 | 
			
		||||
        trim_blocks=True,
 | 
			
		||||
        lstrip_blocks=True,
 | 
			
		||||
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
 | 
			
		||||
    )
 | 
			
		||||
    parsed_content = env.parse(chat_template)
 | 
			
		||||
    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
 | 
			
		||||
 | 
			
		||||
    # We exclude chat_template from kwargs here, because
 | 
			
		||||
    # chat template has been already resolved at this stage
 | 
			
		||||
    unexpected_vars = {"chat_template"}
 | 
			
		||||
    accept_vars = (fn_kw | template_vars) - unexpected_vars
 | 
			
		||||
    return {
 | 
			
		||||
        k: v for k, v in chat_template_kwargs.items() if k in accept_vars
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_hf_chat_template(
 | 
			
		||||
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
 | 
			
		||||
    conversation: list[ConversationMessage],
 | 
			
		||||
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        resolved_kwargs = resolve_chat_template_kwargs(
 | 
			
		||||
            tokenizer=tokenizer,
 | 
			
		||||
            chat_template=hf_chat_template,
 | 
			
		||||
            chat_template_kwargs=kwargs,
 | 
			
		||||
        )
 | 
			
		||||
        return tokenizer.apply_chat_template(
 | 
			
		||||
            conversation=conversation,  # type: ignore[arg-type]
 | 
			
		||||
            tools=tools,  # type: ignore[arg-type]
 | 
			
		||||
            chat_template=hf_chat_template,
 | 
			
		||||
            tokenize=tokenize,
 | 
			
		||||
            **kwargs,
 | 
			
		||||
            **resolved_kwargs,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # External library exceptions can sometimes occur despite the framework's
 | 
			
		||||
 | 
			
		||||
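The new resolve_chat_template_kwargs keeps only kwargs that the template actually references or that apply_chat_template accepts. A standalone illustration of the same filtering idea, using a local render function in place of a tokenizer:

import inspect

import jinja2.meta
import jinja2.sandbox

def render(template: str, add_generation_prompt: bool = False, **kwargs):
    env = jinja2.sandbox.ImmutableSandboxedEnvironment()
    return env.from_string(template).render(
        add_generation_prompt=add_generation_prompt, **kwargs)

def resolve_kwargs(template: str, kwargs: dict) -> dict:
    fn_kw = {k for k in kwargs if k in inspect.signature(render).parameters}
    env = jinja2.sandbox.ImmutableSandboxedEnvironment()
    template_vars = jinja2.meta.find_undeclared_variables(env.parse(template))
    # chat_template itself is always dropped, as in the diff.
    accept = (fn_kw | template_vars) - {"chat_template"}
    return {k: v for k, v in kwargs.items() if k in accept}

tmpl = "{{ greeting }}{% if add_generation_prompt %}<assistant>{% endif %}"
print(resolve_kwargs(tmpl, {"greeting": "hi", "enable_thinking": True,
                            "chat_template": "x"}))
# -> {'greeting': 'hi'}  (unknown and reserved kwargs are filtered out)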
@ -86,6 +86,8 @@ class LLM:
 | 
			
		||||
            or videos from directories specified by the server file system.
 | 
			
		||||
            This is a security risk. Should only be enabled in trusted
 | 
			
		||||
            environments.
 | 
			
		||||
        allowed_media_domains: If set, only media URLs that belong to one of
 | 
			
		||||
            these domains can be used for multi-modal inputs.
 | 
			
		||||
        tensor_parallel_size: The number of GPUs to use for distributed
 | 
			
		||||
            execution with tensor parallelism.
 | 
			
		||||
        dtype: The data type for the model weights and activations. Currently,
 | 
			
		||||
@ -169,6 +171,7 @@ class LLM:
 | 
			
		||||
        skip_tokenizer_init: bool = False,
 | 
			
		||||
        trust_remote_code: bool = False,
 | 
			
		||||
        allowed_local_media_path: str = "",
 | 
			
		||||
        allowed_media_domains: Optional[list[str]] = None,
 | 
			
		||||
        tensor_parallel_size: int = 1,
 | 
			
		||||
        dtype: ModelDType = "auto",
 | 
			
		||||
        quantization: Optional[QuantizationMethods] = None,
 | 
			
		||||
@ -264,6 +267,7 @@ class LLM:
 | 
			
		||||
            skip_tokenizer_init=skip_tokenizer_init,
 | 
			
		||||
            trust_remote_code=trust_remote_code,
 | 
			
		||||
            allowed_local_media_path=allowed_local_media_path,
 | 
			
		||||
            allowed_media_domains=allowed_media_domains,
 | 
			
		||||
            tensor_parallel_size=tensor_parallel_size,
 | 
			
		||||
            dtype=dtype,
 | 
			
		||||
            quantization=quantization,
 | 
			
		||||
 | 
			
		||||
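Hypothetical usage of the new constructor argument (the model name is only a placeholder):

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",             # placeholder multimodal model
    allowed_media_domains=["upload.wikimedia.org"],  # refuse media from other hosts
)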
@ -3,12 +3,14 @@
 | 
			
		||||
 | 
			
		||||
import asyncio
 | 
			
		||||
import gc
 | 
			
		||||
import hashlib
 | 
			
		||||
import importlib
 | 
			
		||||
import inspect
 | 
			
		||||
import json
 | 
			
		||||
import multiprocessing
 | 
			
		||||
import multiprocessing.forkserver as forkserver
 | 
			
		||||
import os
 | 
			
		||||
import secrets
 | 
			
		||||
import signal
 | 
			
		||||
import socket
 | 
			
		||||
import tempfile
 | 
			
		||||
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
 | 
			
		||||
class AuthenticationMiddleware:
 | 
			
		||||
    """
 | 
			
		||||
    Pure ASGI middleware that authenticates each request by checking
 | 
			
		||||
    if the Authorization header exists and equals "Bearer {api_key}".
 | 
			
		||||
    if the Authorization Bearer token exists and equals any of "{api_key}".
 | 
			
		||||
 | 
			
		||||
    Notes
 | 
			
		||||
    -----
 | 
			
		||||
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
 | 
			
		||||
        self.app = app
 | 
			
		||||
        self.api_tokens = {f"Bearer {token}" for token in tokens}
 | 
			
		||||
        self.api_tokens = [
 | 
			
		||||
            hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def verify_token(self, headers: Headers) -> bool:
 | 
			
		||||
        authorization_header_value = headers.get("Authorization")
 | 
			
		||||
        if not authorization_header_value:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        scheme, _, param = authorization_header_value.partition(" ")
 | 
			
		||||
        if scheme.lower() != "bearer":
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        param_hash = hashlib.sha256(param.encode("utf-8")).digest()
 | 
			
		||||
 | 
			
		||||
        token_match = False
 | 
			
		||||
        for token_hash in self.api_tokens:
 | 
			
		||||
            token_match |= secrets.compare_digest(param_hash, token_hash)
 | 
			
		||||
 | 
			
		||||
        return token_match
 | 
			
		||||
 | 
			
		||||
    def __call__(self, scope: Scope, receive: Receive,
 | 
			
		||||
                 send: Send) -> Awaitable[None]:
 | 
			
		||||
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
 | 
			
		||||
        url_path = URL(scope=scope).path.removeprefix(root_path)
 | 
			
		||||
        headers = Headers(scope=scope)
 | 
			
		||||
        # Type narrow to satisfy mypy.
 | 
			
		||||
        if url_path.startswith("/v1") and headers.get(
 | 
			
		||||
                "Authorization") not in self.api_tokens:
 | 
			
		||||
        if url_path.startswith("/v1") and not self.verify_token(headers):
 | 
			
		||||
            response = JSONResponse(content={"error": "Unauthorized"},
 | 
			
		||||
                                    status_code=401)
 | 
			
		||||
            return response(scope, receive, send)
 | 
			
		||||
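A standalone sketch of the hashed, constant-time bearer-token check introduced above; the two keys are placeholders:

import hashlib
import secrets

api_tokens = [hashlib.sha256(t.encode("utf-8")).digest()
              for t in ("key-1", "key-2")]

def verify(authorization_header_value):
    if not authorization_header_value:
        return False
    scheme, _, param = authorization_header_value.partition(" ")
    if scheme.lower() != "bearer":
        return False
    param_hash = hashlib.sha256(param.encode("utf-8")).digest()
    # Compare against every stored hash so timing reveals neither which token
    # matched nor how much of it matched.
    token_match = False
    for token_hash in api_tokens:
        token_match |= secrets.compare_digest(param_hash, token_hash)
    return token_match

assert verify("Bearer key-2")
assert not verify("Bearer wrong") and not verify("Basic key-1")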
@ -1696,6 +1716,7 @@ async def init_app_state(
 | 
			
		||||
        request_logger=request_logger,
 | 
			
		||||
        chat_template=resolved_chat_template,
 | 
			
		||||
        chat_template_content_format=args.chat_template_content_format,
 | 
			
		||||
        trust_request_chat_template=args.trust_request_chat_template,
 | 
			
		||||
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
 | 
			
		||||
        enable_auto_tools=args.enable_auto_tool_choice,
 | 
			
		||||
        exclude_tools_when_tool_choice_none=args.
 | 
			
		||||
 | 
			
		||||
@ -103,9 +103,13 @@ class FrontendArgs:
 | 
			
		||||
    chat_template_content_format: ChatTemplateContentFormatOption = "auto"
 | 
			
		||||
    """The format to render message content within a chat template.
 | 
			
		||||
 | 
			
		||||
* "string" will render the content as a string. Example: `"Hello World"`
 | 
			
		||||
* "openai" will render the content as a list of dictionaries, similar to OpenAI
 | 
			
		||||
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
 | 
			
		||||
    * "string" will render the content as a string. Example: `"Hello World"`
 | 
			
		||||
    * "openai" will render the content as a list of dictionaries, similar to
 | 
			
		||||
      OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
 | 
			
		||||
    trust_request_chat_template: bool = False
 | 
			
		||||
    """Whether to trust the chat template provided in the request. If False,
 | 
			
		||||
    the server will always use the chat template specified by `--chat-template`
 | 
			
		||||
    or the ones from tokenizer."""
 | 
			
		||||
    response_role: str = "assistant"
 | 
			
		||||
    """The role name to return if `request.add_generation_prompt=true`."""
 | 
			
		||||
    ssl_keyfile: Optional[str] = None
 | 
			
		||||
 | 
			
		||||
@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
 | 
			
		||||
        request_logger: Optional[RequestLogger],
 | 
			
		||||
        chat_template: Optional[str],
 | 
			
		||||
        chat_template_content_format: ChatTemplateContentFormatOption,
 | 
			
		||||
        trust_request_chat_template: bool = False,
 | 
			
		||||
        return_tokens_as_token_ids: bool = False,
 | 
			
		||||
        reasoning_parser: str = "",
 | 
			
		||||
        enable_auto_tools: bool = False,
 | 
			
		||||
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
 | 
			
		||||
        self.response_role = response_role
 | 
			
		||||
        self.chat_template = chat_template
 | 
			
		||||
        self.chat_template_content_format: Final = chat_template_content_format
 | 
			
		||||
        self.trust_request_chat_template = trust_request_chat_template
 | 
			
		||||
        self.enable_log_outputs = enable_log_outputs
 | 
			
		||||
 | 
			
		||||
        # set up tool use
 | 
			
		||||
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
 | 
			
		||||
 | 
			
		||||
            if not self.use_harmony:
 | 
			
		||||
                # Common case.
 | 
			
		||||
                request_chat_template = request.chat_template
 | 
			
		||||
                chat_template_kwargs = request.chat_template_kwargs
 | 
			
		||||
                if not self.trust_request_chat_template and (
 | 
			
		||||
                        request_chat_template is not None or
 | 
			
		||||
                    (chat_template_kwargs and
 | 
			
		||||
                     chat_template_kwargs.get("chat_template") is not None)):
 | 
			
		||||
                    return self.create_error_response(
 | 
			
		||||
                        "Chat template is passed with request, but "
 | 
			
		||||
                        "--trust-request-chat-template is not set. "
 | 
			
		||||
                        "Refused request with untrusted chat template.")
 | 
			
		||||
                (
 | 
			
		||||
                    conversation,
 | 
			
		||||
                    request_prompts,
 | 
			
		||||
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
 | 
			
		||||
                    request,
 | 
			
		||||
                    tokenizer,
 | 
			
		||||
                    request.messages,
 | 
			
		||||
                    chat_template=request.chat_template or self.chat_template,
 | 
			
		||||
                    chat_template=request_chat_template or self.chat_template,
 | 
			
		||||
                    chat_template_content_format=self.
 | 
			
		||||
                    chat_template_content_format,
 | 
			
		||||
                    add_generation_prompt=request.add_generation_prompt,
 | 
			
		||||
 | 
			
		||||
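Hedged client-side example of the new guard: with a server launched without --trust-request-chat-template, a request that carries its own template is refused (model name and URL are placeholders):

import openai
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
try:
    client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "hi"}],
        extra_body={"chat_template": "{{ messages[0]['content'] }}"},
    )
except openai.BadRequestError as err:
    # "Chat template is passed with request, but
    #  --trust-request-chat-template is not set. Refused request ..."
    print(err)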
@ -68,6 +68,7 @@ if TYPE_CHECKING:
 | 
			
		||||
    VLLM_IMAGE_FETCH_TIMEOUT: int = 5
 | 
			
		||||
    VLLM_VIDEO_FETCH_TIMEOUT: int = 30
 | 
			
		||||
    VLLM_AUDIO_FETCH_TIMEOUT: int = 10
 | 
			
		||||
    VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
 | 
			
		||||
    VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
 | 
			
		||||
    VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
 | 
			
		||||
    VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
 | 
			
		||||
@ -725,6 +726,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
 | 
			
		||||
    "VLLM_AUDIO_FETCH_TIMEOUT":
 | 
			
		||||
    lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
 | 
			
		||||
 | 
			
		||||
    # Whether to allow HTTP redirects when fetching from media URLs.
 | 
			
		||||
    # Default to True
 | 
			
		||||
    "VLLM_MEDIA_URL_ALLOW_REDIRECTS":
 | 
			
		||||
    lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))),
 | 
			
		||||
 | 
			
		||||
    # Max number of workers for the thread pool handling
 | 
			
		||||
    # media bytes loading. Set to 1 to disable parallel processing.
 | 
			
		||||
    # Default is 8
 | 
			
		||||
 | 
			
		||||
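The flag is read as bool(int(...)), so "0" disables redirect-following and any non-zero value enables it; it must be set before vLLM reads its environment, for example:

import os

os.environ["VLLM_MEDIA_URL_ALLOW_REDIRECTS"] = "0"   # refuse HTTP redirects for media URLs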
@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
 | 
			
		||||
        return BatchDescriptor(self.num_tokens, uniform_decode=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
 | 
			
		||||
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
 | 
			
		||||
                           sequence_parallel_size: int) -> list[int]:
 | 
			
		||||
    sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
 | 
			
		||||
                 sequence_parallel_size)
 | 
			
		||||
 | 
			
		||||
    sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
 | 
			
		||||
    return sp_tokens.tolist()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
 | 
			
		||||
                                      sequence_parallel_size: int,
 | 
			
		||||
                                      max_num_tokens: int,
 | 
			
		||||
                                      chunk_idx: int) -> list[int]:
 | 
			
		||||
    dp_size = len(num_tokens_across_dp_cpu)
 | 
			
		||||
 | 
			
		||||
    local_size = [-1] * dp_size
 | 
			
		||||
    for i in range(dp_size):
 | 
			
		||||
        dp_tokens = num_tokens_across_dp_cpu[i]
 | 
			
		||||
    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
 | 
			
		||||
                                       sequence_parallel_size)
 | 
			
		||||
    sp_size = len(sp_tokens)
 | 
			
		||||
 | 
			
		||||
    local_size = [-1] * sp_size
 | 
			
		||||
    for i in range(sp_size):
 | 
			
		||||
        # Take into account sharding if MoE activation is sequence parallel.
 | 
			
		||||
        local_size[i] = min(max_num_tokens,
 | 
			
		||||
                            dp_tokens - (max_num_tokens * chunk_idx))
 | 
			
		||||
                            sp_tokens[i] - (max_num_tokens * chunk_idx))
 | 
			
		||||
        if local_size[i] <= 0:
 | 
			
		||||
            local_size[i] = 1  # ensure lockstep even if done
 | 
			
		||||
    return local_size
 | 
			
		||||
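Worked example of the two helpers above with made-up token counts (7 and 12 tokens on two DP ranks, two SP ranks each):

import torch

num_tokens_across_dp_cpu = torch.tensor([7, 12])
sequence_parallel_size = 2

sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1)
             // sequence_parallel_size).repeat_interleave(sequence_parallel_size)
print(sp_tokens.tolist())    # [4, 4, 6, 6] -> ceil-split per SP rank

# Chunked variant, at most 3 tokens per rank per chunk:
for chunk_idx in range(2):
    local = [max(min(3, t - 3 * chunk_idx), 1) for t in sp_tokens.tolist()]
    print(chunk_idx, local)  # 0 [3, 3, 3, 3]  /  1 [1, 1, 3, 3]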
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
 | 
			
		||||
@dataclass
 | 
			
		||||
class DPMetadata:
 | 
			
		||||
    max_tokens_across_dp_cpu: torch.Tensor
 | 
			
		||||
    cu_tokens_across_dp_cpu: torch.Tensor
 | 
			
		||||
    num_tokens_across_dp_cpu: torch.Tensor
 | 
			
		||||
 | 
			
		||||
    # NOTE: local_sizes should only be set by the chunked_sizes context manager
 | 
			
		||||
    local_sizes: Optional[list[int]] = None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -98,6 +113,17 @@ class DPMetadata:
 | 
			
		||||
        dist.all_reduce(num_tokens_tensor, group=group)
 | 
			
		||||
        return num_tokens_tensor.cpu()
 | 
			
		||||
 | 
			
		||||
    # Get the cumulative tokens across sequence parallel ranks.
 | 
			
		||||
    # In this case the input to the MoEs will be distributed w.r.t both
 | 
			
		||||
    # DP and TP rank.
 | 
			
		||||
    # When sp_size==1, this is just the cumulative num tokens across DP.
 | 
			
		||||
    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
 | 
			
		||||
        num_tokens_across_sp_cpu = (
 | 
			
		||||
            (self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
 | 
			
		||||
        num_tokens_across_sp_cpu = (
 | 
			
		||||
            num_tokens_across_sp_cpu.repeat_interleave(sp_size))
 | 
			
		||||
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def should_ubatch_across_dp(
 | 
			
		||||
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
 | 
			
		||||
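With sp_size == 1 the new cu_tokens_across_sp degenerates to the old cumulative token count across DP; a quick check with the same made-up counts:

import torch

num_tokens_across_dp_cpu = torch.tensor([7, 12])

def cu_tokens_across_sp(sp_size):
    per_sp = (num_tokens_across_dp_cpu - 1 + sp_size) // sp_size
    return torch.cumsum(per_sp.repeat_interleave(sp_size), dim=0)

print(cu_tokens_across_sp(1).tolist())   # [7, 19]
print(cu_tokens_across_sp(2).tolist())   # [4, 8, 14, 20]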
@ -147,10 +173,10 @@ class DPMetadata:
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def make(
 | 
			
		||||
            parallel_config: ParallelConfig,
 | 
			
		||||
            attn_metadata: Any,
 | 
			
		||||
            num_tokens: int,
 | 
			
		||||
            num_tokens_across_dp: Optional[torch.Tensor] = None
 | 
			
		||||
        parallel_config: ParallelConfig,
 | 
			
		||||
        attn_metadata: Any,
 | 
			
		||||
        num_tokens: int,
 | 
			
		||||
        num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
 | 
			
		||||
    ) -> "DPMetadata":
 | 
			
		||||
 | 
			
		||||
        assert parallel_config.data_parallel_size > 1
 | 
			
		||||
@ -167,18 +193,18 @@ class DPMetadata:
 | 
			
		||||
 | 
			
		||||
        # If num_tokens_across_dp is None, it will be computed by all_reduce
 | 
			
		||||
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
 | 
			
		||||
        assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
 | 
			
		||||
                == batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
 | 
			
		||||
        if num_tokens_across_dp is None:
 | 
			
		||||
            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
 | 
			
		||||
        assert (num_tokens_across_dp_cpu is None
 | 
			
		||||
                or num_tokens_across_dp_cpu[dp_rank] == batchsize
 | 
			
		||||
                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
 | 
			
		||||
        if num_tokens_across_dp_cpu is None:
 | 
			
		||||
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
 | 
			
		||||
                batchsize, dp_size, dp_rank)
 | 
			
		||||
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
 | 
			
		||||
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
 | 
			
		||||
        return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
 | 
			
		||||
                          num_tokens_across_dp)
 | 
			
		||||
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
 | 
			
		||||
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
 | 
			
		||||
 | 
			
		||||
    @contextmanager
 | 
			
		||||
    def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
 | 
			
		||||
    def chunked_sizes(self, sequence_parallel_size: int,
 | 
			
		||||
                      max_chunk_size_per_rank: int, chunk_idx: int):
 | 
			
		||||
        """
 | 
			
		||||
        Context manager to compute and temporarily set the per-rank local token
 | 
			
		||||
        sizes for a specific chunk during chunked forward execution.
 | 
			
		||||
@ -192,31 +218,40 @@ class DPMetadata:
 | 
			
		||||
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
 | 
			
		||||
        of tokens to process in that chunk on each rank.
 | 
			
		||||
 | 
			
		||||
        It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
 | 
			
		||||
        number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
 | 
			
		||||
        to determine the chunk-wise split.
 | 
			
		||||
 | 
			
		||||
        `self.local_sizes` is only valid inside the context.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
 | 
			
		||||
                                    we use SP between the layers to avoid
 | 
			
		||||
                                    redundant ops. We need this value to
 | 
			
		||||
                                    compute the chunked sizes.
 | 
			
		||||
            max_chunk_size_per_rank: The max number of tokens each rank is 
 | 
			
		||||
                                     allowed to process in this chunk.
 | 
			
		||||
            chunk_idx: The index of the chunk to compute sizes for.
 | 
			
		||||
        """
 | 
			
		||||
        cu_sizes = self.cu_tokens_across_dp_cpu
 | 
			
		||||
        num_tokens_across_dp_cpu = [
 | 
			
		||||
            (cu_sizes[i] -
 | 
			
		||||
             cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
 | 
			
		||||
            for i in range(len(cu_sizes))
 | 
			
		||||
        ]
 | 
			
		||||
        self.local_sizes = _compute_chunked_local_num_tokens(
 | 
			
		||||
            num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
 | 
			
		||||
            self.num_tokens_across_dp_cpu, sequence_parallel_size,
 | 
			
		||||
            max_chunk_size_per_rank, chunk_idx)
 | 
			
		||||
        try:
 | 
			
		||||
            yield self.local_sizes
 | 
			
		||||
        finally:
 | 
			
		||||
            self.local_sizes = None
 | 
			
		||||
 | 
			
		||||
    @contextmanager
 | 
			
		||||
    def sp_local_sizes(self, sequence_parallel_size: int):
 | 
			
		||||
        """
 | 
			
		||||
        Context manager for setting self.local_sizes. Same as self.chunked_sizes
 | 
			
		||||
        but without any chunking.
 | 
			
		||||
        """
 | 
			
		||||
        self.local_sizes = _compute_sp_num_tokens(
 | 
			
		||||
            self.num_tokens_across_dp_cpu, sequence_parallel_size)
 | 
			
		||||
        try:
 | 
			
		||||
            yield self.local_sizes
 | 
			
		||||
        finally:
 | 
			
		||||
            self.local_sizes = None
 | 
			
		||||
 | 
			
		||||
    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
 | 
			
		||||
        assert self.local_sizes is not None
 | 
			
		||||
        return self.local_sizes
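The helpers `_compute_chunked_local_num_tokens` and `_compute_sp_num_tokens` referenced above are not shown in this hunk. A minimal, hypothetical sketch of the chunked split, assuming each DP rank's token count is first divided across sequence-parallel ranks and each chunk is clamped to at least one token so collectives stay shape-aligned:

import math

def chunked_local_num_tokens(num_tokens_across_dp: list[int],
                             sequence_parallel_size: int,
                             max_chunk_size_per_rank: int,
                             chunk_idx: int) -> list[int]:
    # Hypothetical helper, not the vLLM implementation.
    sizes = []
    for n in num_tokens_across_dp:
        n_sp = math.ceil(n / sequence_parallel_size)      # tokens per SP rank
        remaining = n_sp - chunk_idx * max_chunk_size_per_rank
        # Clamp to at least one token so collective ops stay shape-aligned.
        sizes.append(max(min(max_chunk_size_per_rank, remaining), 1))
    return sizes

# Example: 3 DP ranks with 20/9/0 tokens, SP disabled, chunks of 8 per rank.
print(chunked_local_num_tokens([20, 9, 0], 1, 8, 0))  # [8, 8, 1]
print(chunked_local_num_tokens([20, 9, 0], 1, 8, 1))  # [8, 1, 1]
print(chunked_local_num_tokens([20, 9, 0], 1, 8, 2))  # [4, 1, 1]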
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -3,6 +3,7 @@
 | 
			
		||||
 | 
			
		||||
from abc import abstractmethod
 | 
			
		||||
from collections.abc import Iterable
 | 
			
		||||
from contextlib import nullcontext
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Callable, Literal, Optional, Union, get_args, overload
 | 
			
		||||
 | 
			
		||||
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
 | 
			
		||||
                    if dp_size is not None else get_dp_group().world_size)
 | 
			
		||||
 | 
			
		||||
        self.is_sequence_parallel = is_sequence_parallel
 | 
			
		||||
        if self.is_sequence_parallel:
 | 
			
		||||
            self.sp_size = tp_size_
 | 
			
		||||
        self.sp_size = tp_size_ if is_sequence_parallel else 1
 | 
			
		||||
 | 
			
		||||
        self.moe_parallel_config: FusedMoEParallelConfig = (
 | 
			
		||||
            FusedMoEParallelConfig.make(
 | 
			
		||||
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
 | 
			
		||||
            # clamp start and end
 | 
			
		||||
            chunk_start = min(chunk_start, num_tokens - 1)
 | 
			
		||||
            chunk_end = min(chunk_end, num_tokens)
 | 
			
		||||
            with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
 | 
			
		||||
            with ctx.dp_metadata.chunked_sizes(self.sp_size,
 | 
			
		||||
                                               moe_dp_chunk_size_per_rank,
 | 
			
		||||
                                               chunk_idx):
 | 
			
		||||
                process_chunk(chunk_start,
 | 
			
		||||
                              chunk_end,
 | 
			
		||||
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
 | 
			
		||||
        else:
 | 
			
		||||
            shared_output = None
 | 
			
		||||
 | 
			
		||||
        if do_naive_dispatch_combine:
 | 
			
		||||
            hidden_states, router_logits = get_ep_group().dispatch(
 | 
			
		||||
                hidden_states, router_logits)
 | 
			
		||||
        ctx = get_forward_context()
 | 
			
		||||
        sp_ctx = ctx.dp_metadata.sp_local_sizes(
 | 
			
		||||
            self.sp_size) if ctx.dp_metadata else nullcontext()
 | 
			
		||||
 | 
			
		||||
        # Matrix multiply.
 | 
			
		||||
        final_hidden_states = self.quant_method.apply(
 | 
			
		||||
            layer=self,
 | 
			
		||||
            x=hidden_states,
 | 
			
		||||
            router_logits=router_logits,
 | 
			
		||||
            top_k=self.top_k,
 | 
			
		||||
            renormalize=self.renormalize,
 | 
			
		||||
            use_grouped_topk=self.use_grouped_topk,
 | 
			
		||||
            global_num_experts=self.global_num_experts,
 | 
			
		||||
            expert_map=self.expert_map,
 | 
			
		||||
            topk_group=self.topk_group,
 | 
			
		||||
            num_expert_group=self.num_expert_group,
 | 
			
		||||
            custom_routing_function=self.custom_routing_function,
 | 
			
		||||
            scoring_func=self.scoring_func,
 | 
			
		||||
            routed_scaling_factor=self.routed_scaling_factor,
 | 
			
		||||
            e_score_correction_bias=self.e_score_correction_bias,
 | 
			
		||||
            activation=self.activation,
 | 
			
		||||
            apply_router_weight_on_input=self.apply_router_weight_on_input,
 | 
			
		||||
            enable_eplb=self.enable_eplb,
 | 
			
		||||
            expert_load_view=self.expert_load_view,
 | 
			
		||||
            logical_to_physical_map=self.logical_to_physical_map,
 | 
			
		||||
            logical_replica_count=self.logical_replica_count,
 | 
			
		||||
        )
 | 
			
		||||
        with sp_ctx:
 | 
			
		||||
            if do_naive_dispatch_combine:
 | 
			
		||||
                hidden_states, router_logits = get_ep_group().dispatch(
 | 
			
		||||
                    hidden_states, router_logits, self.is_sequence_parallel)
 | 
			
		||||
 | 
			
		||||
        if shared_output is not None:
 | 
			
		||||
            assert not isinstance(final_hidden_states, tuple)
 | 
			
		||||
            assert self.shared_experts is not None
 | 
			
		||||
            final_hidden_states = (
 | 
			
		||||
                shared_output,
 | 
			
		||||
                final_hidden_states,
 | 
			
		||||
            # Matrix multiply.
 | 
			
		||||
            final_hidden_states = self.quant_method.apply(
 | 
			
		||||
                layer=self,
 | 
			
		||||
                x=hidden_states,
 | 
			
		||||
                router_logits=router_logits,
 | 
			
		||||
                top_k=self.top_k,
 | 
			
		||||
                renormalize=self.renormalize,
 | 
			
		||||
                use_grouped_topk=self.use_grouped_topk,
 | 
			
		||||
                global_num_experts=self.global_num_experts,
 | 
			
		||||
                expert_map=self.expert_map,
 | 
			
		||||
                topk_group=self.topk_group,
 | 
			
		||||
                num_expert_group=self.num_expert_group,
 | 
			
		||||
                custom_routing_function=self.custom_routing_function,
 | 
			
		||||
                scoring_func=self.scoring_func,
 | 
			
		||||
                routed_scaling_factor=self.routed_scaling_factor,
 | 
			
		||||
                e_score_correction_bias=self.e_score_correction_bias,
 | 
			
		||||
                activation=self.activation,
 | 
			
		||||
                apply_router_weight_on_input=self.apply_router_weight_on_input,
 | 
			
		||||
                enable_eplb=self.enable_eplb,
 | 
			
		||||
                expert_load_view=self.expert_load_view,
 | 
			
		||||
                logical_to_physical_map=self.logical_to_physical_map,
 | 
			
		||||
                logical_replica_count=self.logical_replica_count,
 | 
			
		||||
            )
 | 
			
		||||
        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
 | 
			
		||||
            assert isinstance(final_hidden_states, tuple)
 | 
			
		||||
            final_hidden_states, zero_expert_result = final_hidden_states
 | 
			
		||||
 | 
			
		||||
        def reduce_output(states: torch.Tensor,
 | 
			
		||||
                          do_combine: bool = True) -> torch.Tensor:
 | 
			
		||||
            if do_naive_dispatch_combine and do_combine:
 | 
			
		||||
                states = get_ep_group().combine(states)
 | 
			
		||||
            if shared_output is not None:
 | 
			
		||||
                assert not isinstance(final_hidden_states, tuple)
 | 
			
		||||
                assert self.shared_experts is not None
 | 
			
		||||
                final_hidden_states = (
 | 
			
		||||
                    shared_output,
 | 
			
		||||
                    final_hidden_states,
 | 
			
		||||
                )
 | 
			
		||||
            elif self.zero_expert_num is not None and self.zero_expert_num > 0:
 | 
			
		||||
                assert isinstance(final_hidden_states, tuple)
 | 
			
		||||
                final_hidden_states, zero_expert_result = final_hidden_states
 | 
			
		||||
 | 
			
		||||
            if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
 | 
			
		||||
                states = self.maybe_all_reduce_tensor_model_parallel(states)
 | 
			
		||||
            def reduce_output(states: torch.Tensor,
 | 
			
		||||
                              do_combine: bool = True) -> torch.Tensor:
 | 
			
		||||
                if do_naive_dispatch_combine and do_combine:
 | 
			
		||||
                    states = get_ep_group().combine(states,
 | 
			
		||||
                                                    self.is_sequence_parallel)
 | 
			
		||||
 | 
			
		||||
            return states
 | 
			
		||||
                if (not self.is_sequence_parallel and self.reduce_results
 | 
			
		||||
                        and (self.tp_size > 1 or self.ep_size > 1)):
 | 
			
		||||
                    states = self.maybe_all_reduce_tensor_model_parallel(
 | 
			
		||||
                        states)
 | 
			
		||||
 | 
			
		||||
        if self.shared_experts is not None:
 | 
			
		||||
            return (
 | 
			
		||||
                reduce_output(final_hidden_states[0], do_combine=False),
 | 
			
		||||
                reduce_output(final_hidden_states[1]),
 | 
			
		||||
            )
 | 
			
		||||
        elif self.zero_expert_num is not None and self.zero_expert_num > 0:
 | 
			
		||||
            assert isinstance(final_hidden_states, torch.Tensor)
 | 
			
		||||
            return reduce_output(final_hidden_states) + zero_expert_result
 | 
			
		||||
        else:
 | 
			
		||||
            return reduce_output(final_hidden_states)
 | 
			
		||||
                return states
 | 
			
		||||
 | 
			
		||||
            if self.shared_experts is not None:
 | 
			
		||||
                return (
 | 
			
		||||
                    reduce_output(final_hidden_states[0], do_combine=False),
 | 
			
		||||
                    reduce_output(final_hidden_states[1]),
 | 
			
		||||
                )
 | 
			
		||||
            elif self.zero_expert_num is not None and self.zero_expert_num > 0:
 | 
			
		||||
                assert isinstance(final_hidden_states, torch.Tensor)
 | 
			
		||||
                return reduce_output(final_hidden_states) + zero_expert_result
 | 
			
		||||
            else:
 | 
			
		||||
                return reduce_output(final_hidden_states)
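The `sp_ctx = ... if ctx.dp_metadata else nullcontext()` line above makes the sequence-parallel sizing context optional. A standalone sketch of that idiom, with illustrative names that are not vLLM APIs:

from contextlib import contextmanager, nullcontext

class Meta:
    def __init__(self):
        self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sizes):
        # Temporarily expose per-rank sizes, always resetting afterwards,
        # mirroring how DPMetadata scopes self.local_sizes.
        self.local_sizes = sizes
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

meta = Meta()              # pretend this may also be None (no DP metadata)
ctx = meta.sp_local_sizes([4, 4, 3]) if meta is not None else nullcontext()
with ctx:
    pass                   # same code path whether or not metadata exists
assert meta.local_sizes is None   # reset on exit either way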
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def make_expert_params_mapping(
 | 
			
		||||
 | 
			
		||||
@ -5,6 +5,7 @@ from typing import Optional, Union
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn as nn
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
import vllm.envs as envs
 | 
			
		||||
from vllm.model_executor.custom_op import CustomOp
 | 
			
		||||
@ -375,3 +376,20 @@ class PolyNorm(CustomOp):
 | 
			
		||||
        x: torch.Tensor,
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LayerNorm(nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Layer Normalization.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, dim: int, eps: float = 1e-6):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.dim = dim
 | 
			
		||||
        self.eps = eps
 | 
			
		||||
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
 | 
			
		||||
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
 | 
			
		||||
 | 
			
		||||
    def forward(self, x: torch.Tensor):
 | 
			
		||||
        return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
 | 
			
		||||
                            self.eps).type_as(x)
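A quick usage sketch for the `LayerNorm` module added above; the class body is repeated verbatim so the snippet is self-contained, and the bfloat16 input is only an example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm(nn.Module):
    """Layer norm computed in float32, result cast back to the input dtype."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
                            self.eps).type_as(x)

x = torch.randn(2, 128, dtype=torch.bfloat16)
y = LayerNorm(128)(x)
assert y.dtype == torch.bfloat16 and y.shape == x.shape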
 | 
			
		||||
 | 
			
		||||
@ -24,6 +24,9 @@ class MLAModules:
 | 
			
		||||
    q_a_layernorm: Optional[torch.nn.Module]
 | 
			
		||||
    q_b_proj: Optional[torch.nn.Module]
 | 
			
		||||
    q_proj: Optional[torch.nn.Module]
 | 
			
		||||
    indexer: Optional[torch.nn.Module]
 | 
			
		||||
    is_sparse: bool
 | 
			
		||||
    topk_indices_buffer: Optional[torch.Tensor]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@CustomOp.register("multi_head_latent_attention")
 | 
			
		||||
@ -76,6 +79,13 @@ class MultiHeadLatentAttention(CustomOp):
 | 
			
		||||
        self.kv_b_proj = mla_modules.kv_b_proj
 | 
			
		||||
        self.rotary_emb = mla_modules.rotary_emb
 | 
			
		||||
        self.o_proj = mla_modules.o_proj
 | 
			
		||||
        self.indexer = mla_modules.indexer
 | 
			
		||||
        self.is_sparse = mla_modules.is_sparse
 | 
			
		||||
 | 
			
		||||
        if self.indexer is not None:
 | 
			
		||||
            assert hasattr(self.indexer, "topk_tokens")
 | 
			
		||||
            self.topk_tokens = self.indexer.topk_tokens
 | 
			
		||||
            self.topk_indices_buffer = mla_modules.topk_indices_buffer
 | 
			
		||||
 | 
			
		||||
        # In the MLA backend, kv_cache includes both k_c and
 | 
			
		||||
        # pe (i.e. decoupled position embeddings). In particular,
 | 
			
		||||
@ -92,6 +102,7 @@ class MultiHeadLatentAttention(CustomOp):
 | 
			
		||||
            quant_config=quant_config,
 | 
			
		||||
            prefix=f"{prefix}.attn",
 | 
			
		||||
            use_mla=True,
 | 
			
		||||
            use_sparse=mla_modules.is_sparse,
 | 
			
		||||
            # MLA Args
 | 
			
		||||
            q_lora_rank=self.q_lora_rank,
 | 
			
		||||
            kv_lora_rank=self.kv_lora_rank,
 | 
			
		||||
@ -100,6 +111,7 @@ class MultiHeadLatentAttention(CustomOp):
 | 
			
		||||
            qk_head_dim=self.qk_head_dim,
 | 
			
		||||
            v_head_dim=self.v_head_dim,
 | 
			
		||||
            kv_b_proj=self.kv_b_proj,
 | 
			
		||||
            indexer=self.indexer,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.prefix = prefix
 | 
			
		||||
@ -145,6 +157,10 @@ class MultiHeadLatentAttention(CustomOp):
 | 
			
		||||
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
 | 
			
		||||
            positions, q[..., self.qk_nope_head_dim:], k_pe)
 | 
			
		||||
 | 
			
		||||
        if self.indexer and self.is_sparse:
 | 
			
		||||
            _topk_indices = self.indexer(hidden_states, q_c, positions,
 | 
			
		||||
                                         self.rotary_emb)
 | 
			
		||||
 | 
			
		||||
        attn_out = self.mla_attn(
 | 
			
		||||
            q,
 | 
			
		||||
            kv_c_normed,
 | 
			
		||||
 | 
			
		||||
@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
 | 
			
		||||
    # On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
 | 
			
		||||
    # requantize the weight and input to the specific scale
 | 
			
		||||
    # at the same time.
 | 
			
		||||
    if is_deep_gemm_e8m0_used():
 | 
			
		||||
    should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
 | 
			
		||||
        layer.orig_dtype, layer.weight)
 | 
			
		||||
    if is_deep_gemm_e8m0_used() and should_use_deepgemm:
 | 
			
		||||
        block_sz = tuple(layer.weight_block_size)
 | 
			
		||||
        requant_weight_ue8m0_inplace(layer.weight.data,
 | 
			
		||||
                                     layer.weight_scale.data, block_sz)
 | 
			
		||||
    # SM90 Block FP8 CUTLASS requires row-major weight scales
 | 
			
		||||
    elif (current_platform.is_device_capability(90)
 | 
			
		||||
          and cutlass_block_fp8_supported
 | 
			
		||||
          and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
 | 
			
		||||
                                                     layer.weight)):
 | 
			
		||||
          and cutlass_block_fp8_supported and not should_use_deepgemm):
 | 
			
		||||
        layer.weight_scale = torch.nn.Parameter(
 | 
			
		||||
            layer.weight_scale.data.T.contiguous(), requires_grad=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
 | 
			
		||||
from transformers.models.aria.modeling_aria import AriaCrossAttention
 | 
			
		||||
from transformers.models.aria.processing_aria import AriaProcessor
 | 
			
		||||
 | 
			
		||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
 | 
			
		||||
from vllm.config import QuantizationConfig, VllmConfig
 | 
			
		||||
from vllm.distributed import get_tensor_model_parallel_rank
 | 
			
		||||
from vllm.model_executor.layers.activation import get_act_fn
 | 
			
		||||
from vllm.model_executor.layers.fused_moe import FusedMoE
 | 
			
		||||
@ -298,14 +298,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
 | 
			
		||||
    Experts (MoE) Layer.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        config: AriaTextConfig,
 | 
			
		||||
        cache_config: Optional[CacheConfig] = None,
 | 
			
		||||
        quant_config: Optional[QuantizationConfig] = None,
 | 
			
		||||
        prefix: str = "",
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__(config, cache_config, quant_config, prefix)
 | 
			
		||||
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
 | 
			
		||||
        super().__init__(vllm_config, prefix)
 | 
			
		||||
 | 
			
		||||
        config = vllm_config.model_config.hf_config
 | 
			
		||||
        quant_config = vllm_config.quant_config
 | 
			
		||||
 | 
			
		||||
        self.mlp = AriaTextMoELayer(config,
 | 
			
		||||
                                    quant_config=quant_config,
 | 
			
		||||
                                    prefix=f"{prefix}.mlp")
 | 
			
		||||
 | 
			
		||||
@ -346,8 +346,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 | 
			
		||||
            block_size=1,
 | 
			
		||||
            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
 | 
			
		||||
            head_size=model_config.get_head_size(),
 | 
			
		||||
            dtype=kv_cache_dtype,
 | 
			
		||||
            use_mla=model_config.use_mla).page_size_bytes
 | 
			
		||||
            dtype=kv_cache_dtype).page_size_bytes
 | 
			
		||||
 | 
			
		||||
        model_cls, _ = ModelRegistry.resolve_model_cls(
 | 
			
		||||
            model_config.architecture,
 | 
			
		||||
@ -401,6 +400,31 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 | 
			
		||||
                "exactly equal.", mamba_padding_pct)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Update the fp8 kv-cache to the custom "fp8_ds_mla" format for DeepSeekV3.2.
 | 
			
		||||
        """
 | 
			
		||||
        hf_config = vllm_config.model_config.hf_config
 | 
			
		||||
 | 
			
		||||
        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
 | 
			
		||||
        is_v32 = hasattr(hf_config, "index_topk")
 | 
			
		||||
        assert is_v32
 | 
			
		||||
 | 
			
		||||
        # For DeepSeekV3.2, we use a custom fp8 format as default (i.e.
 | 
			
		||||
        #   "auto")
 | 
			
		||||
        cache_config = vllm_config.cache_config
 | 
			
		||||
        if cache_config.cache_dtype == "auto" or \
 | 
			
		||||
            cache_config.cache_dtype.startswith("fp8"):
 | 
			
		||||
            cache_config.cache_dtype = "fp8_ds_mla"
 | 
			
		||||
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
 | 
			
		||||
        if cache_config.cache_dtype == "bfloat16":
 | 
			
		||||
            cache_config.cache_dtype = "auto"
 | 
			
		||||
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
 | 
			
		||||
    "GteModel": SnowflakeGteNewModelConfig,
 | 
			
		||||
    "GteNewModel": GteNewModelConfig,
 | 
			
		||||
@ -417,4 +441,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
 | 
			
		||||
    "MambaForCausalLM": MambaModelConfig,
 | 
			
		||||
    "Mamba2ForCausalLM": MambaModelConfig,
 | 
			
		||||
    "FalconMambaForCausalLM": MambaModelConfig,
 | 
			
		||||
    "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -53,8 +53,20 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
 | 
			
		||||
        self.eh_proj = nn.Linear(config.hidden_size * 2,
 | 
			
		||||
                                 config.hidden_size,
 | 
			
		||||
                                 bias=False)
 | 
			
		||||
 | 
			
		||||
        self.is_v32 = hasattr(config, "index_topk")
 | 
			
		||||
        if self.is_v32:
 | 
			
		||||
            topk_tokens = config.index_topk
 | 
			
		||||
            topk_indices_buffer = torch.empty(
 | 
			
		||||
                vllm_config.scheduler_config.max_num_batched_tokens,
 | 
			
		||||
                topk_tokens,
 | 
			
		||||
                dtype=torch.int32,
 | 
			
		||||
                device="cuda")
 | 
			
		||||
        else:
 | 
			
		||||
            topk_indices_buffer = None
 | 
			
		||||
        self.shared_head = SharedHead(config=config, quant_config=quant_config)
 | 
			
		||||
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
 | 
			
		||||
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
 | 
			
		||||
                                                topk_indices_buffer)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
@ -32,17 +32,22 @@ import torch
 | 
			
		||||
from torch import nn
 | 
			
		||||
from transformers import DeepseekV2Config, DeepseekV3Config
 | 
			
		||||
 | 
			
		||||
import vllm.envs as envs
 | 
			
		||||
from vllm.attention import Attention
 | 
			
		||||
from vllm.attention.backends.abstract import AttentionBackend
 | 
			
		||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
 | 
			
		||||
from vllm.compilation.decorators import support_torch_compile
 | 
			
		||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
 | 
			
		||||
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
 | 
			
		||||
                         get_current_vllm_config)
 | 
			
		||||
from vllm.distributed import (get_ep_group, get_pp_group,
 | 
			
		||||
                              get_tensor_model_parallel_rank,
 | 
			
		||||
                              get_tensor_model_parallel_world_size,
 | 
			
		||||
                              tensor_model_parallel_all_gather)
 | 
			
		||||
from vllm.forward_context import get_forward_context
 | 
			
		||||
from vllm.logger import init_logger
 | 
			
		||||
from vllm.model_executor.layers.activation import SiluAndMul
 | 
			
		||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 | 
			
		||||
from vllm.model_executor.layers.fused_moe import FusedMoE
 | 
			
		||||
from vllm.model_executor.layers.layernorm import RMSNorm
 | 
			
		||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 | 
			
		||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 | 
			
		||||
                                               MergedColumnParallelLinear,
 | 
			
		||||
                                               ReplicatedLinear,
 | 
			
		||||
@ -50,20 +55,35 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 | 
			
		||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
 | 
			
		||||
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
 | 
			
		||||
from vllm.model_executor.layers.quantization import QuantizationConfig
 | 
			
		||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
 | 
			
		||||
    per_token_group_quant_fp8)
 | 
			
		||||
from vllm.model_executor.layers.rotary_embedding import get_rope
 | 
			
		||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
 | 
			
		||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
 | 
			
		||||
    ParallelLMHead, VocabParallelEmbedding)
 | 
			
		||||
from vllm.model_executor.model_loader.weight_utils import (
 | 
			
		||||
    default_weight_loader, maybe_remap_kv_scale_name)
 | 
			
		||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
 | 
			
		||||
from vllm.platforms import current_platform
 | 
			
		||||
from vllm.sequence import IntermediateTensors
 | 
			
		||||
from vllm.utils import cdiv, direct_register_custom_op
 | 
			
		||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
 | 
			
		||||
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
 | 
			
		||||
                                                    DeepseekV32IndexerMetadata)
 | 
			
		||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
 | 
			
		||||
 | 
			
		||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 | 
			
		||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
 | 
			
		||||
                    make_empty_intermediate_tensors_factory, make_layers,
 | 
			
		||||
                    maybe_prefix)
 | 
			
		||||
 | 
			
		||||
if current_platform.is_cuda_alike():
 | 
			
		||||
    from vllm import _custom_ops as ops
 | 
			
		||||
elif current_platform.is_xpu():
 | 
			
		||||
    from vllm._ipex_ops import ipex_ops as ops
 | 
			
		||||
 | 
			
		||||
logger = init_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeepseekV2MLP(nn.Module):
 | 
			
		||||
 | 
			
		||||
@ -108,43 +128,6 @@ class DeepseekV2MLP(nn.Module):
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Chunk x along the num_tokens axis for sequence parallelism
 | 
			
		||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
 | 
			
		||||
# The output tensor can have a sequence length 0 at small input sequence lengths
 | 
			
		||||
# even though we explicitly pad to avoid this.
 | 
			
		||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    tp_size = get_tensor_model_parallel_world_size()
 | 
			
		||||
    tp_rank = get_tensor_model_parallel_rank()
 | 
			
		||||
 | 
			
		||||
    # all_gather needs the sequence length to be divisible by tp_size
 | 
			
		||||
    seq_len = x.size(0)
 | 
			
		||||
    remainder = seq_len % tp_size
 | 
			
		||||
    if remainder != 0:
 | 
			
		||||
        pad_len = tp_size - remainder
 | 
			
		||||
        x = nn.functional.pad(x, (0, 0, 0, pad_len))
 | 
			
		||||
 | 
			
		||||
    chunk = x.shape[0] // tp_size
 | 
			
		||||
    start = tp_rank * chunk
 | 
			
		||||
    return torch.narrow(x, 0, start, chunk)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    tp_size = get_tensor_model_parallel_world_size()
 | 
			
		||||
    seq_len = cdiv(x.size(0), tp_size)
 | 
			
		||||
    shape = list(x.shape)
 | 
			
		||||
    shape[0] = seq_len
 | 
			
		||||
    out = torch.empty(shape, dtype=x.dtype, device=x.device)
 | 
			
		||||
    return out
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
direct_register_custom_op(
 | 
			
		||||
    op_name="sequence_parallel_chunk",
 | 
			
		||||
    op_func=sequence_parallel_chunk,
 | 
			
		||||
    fake_impl=sequence_parallel_chunk_fake,
 | 
			
		||||
    tags=(torch.Tag.needs_fixed_stride_order, ),
 | 
			
		||||
)
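The `sequence_parallel_chunk` op above (now imported from `vllm.model_executor.models.utils` instead of being defined here) pads the token dimension and takes one contiguous slice per TP rank. A distribution-free sketch of the same pad-and-narrow arithmetic, with rank and world size passed explicitly for illustration:

import torch

def chunk_for_rank(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Pad num_tokens up to a multiple of tp_size, then take this rank's slice.
    remainder = x.size(0) % tp_size
    if remainder != 0:
        x = torch.nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.size(0) // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.randn(10, 16)                           # 10 tokens, hidden size 16
parts = [chunk_for_rank(x, 4, r) for r in range(4)]
assert all(p.shape == (3, 16) for p in parts)     # padded 10 -> 12, 3 per rank
assert torch.equal(torch.cat(parts)[:10], x)      # concatenation recovers the input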
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeepseekV2MoE(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
@ -166,20 +149,7 @@ class DeepseekV2MoE(nn.Module):
 | 
			
		||||
        self.n_routed_experts: int = config.n_routed_experts
 | 
			
		||||
        self.n_shared_experts: int = config.n_shared_experts
 | 
			
		||||
 | 
			
		||||
        # The all_reduce at the end of attention (during o_proj) means that
 | 
			
		||||
        # inputs are replicated across each rank of the tensor parallel group.
 | 
			
		||||
        # If using expert-parallelism with DeepEP All2All ops, replicated
 | 
			
		||||
        # tokens results in useless duplicate computation and communication.
 | 
			
		||||
        #
 | 
			
		||||
        # In this case, ensure the input to the experts is sequence parallel
 | 
			
		||||
        # to avoid the excess work.
 | 
			
		||||
        #
 | 
			
		||||
        # Not needed for pplx-kernels as it can handle duplicate input tokens.
 | 
			
		||||
        self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
 | 
			
		||||
                                     in ("deepep_high_throughput",
 | 
			
		||||
                                         "deepep_low_latency")
 | 
			
		||||
                                     and parallel_config.enable_expert_parallel
 | 
			
		||||
                                     and self.tp_size > 1)
 | 
			
		||||
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
 | 
			
		||||
 | 
			
		||||
        if config.hidden_act != "silu":
 | 
			
		||||
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
 | 
			
		||||
@ -278,8 +248,7 @@ class DeepseekV2MoE(nn.Module):
 | 
			
		||||
        # TODO: We can replace the all_reduce at the end of attn with a
 | 
			
		||||
        # reduce_scatter instead of chunking here.
 | 
			
		||||
        if self.is_sequence_parallel:
 | 
			
		||||
            hidden_states = torch.ops.vllm.sequence_parallel_chunk(
 | 
			
		||||
                hidden_states)
 | 
			
		||||
            hidden_states = sequence_parallel_chunk(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # router_logits: (num_tokens, n_experts)
 | 
			
		||||
        router_logits, _ = self.gate(hidden_states)
 | 
			
		||||
@ -328,6 +297,7 @@ class DeepseekV2Attention(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        vllm_config: VllmConfig,
 | 
			
		||||
        config: Union[DeepseekV2Config, DeepseekV3Config],
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        num_heads: int,
 | 
			
		||||
@ -341,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
 | 
			
		||||
        max_position_embeddings: int = 8192,
 | 
			
		||||
        cache_config: Optional[CacheConfig] = None,
 | 
			
		||||
        quant_config: Optional[QuantizationConfig] = None,
 | 
			
		||||
        topk_indices_buffer: Optional[torch.Tensor] = None,
 | 
			
		||||
        prefix: str = "",
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@ -358,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
 | 
			
		||||
        self.scaling = self.qk_head_dim**-0.5
 | 
			
		||||
        self.rope_theta = rope_theta
 | 
			
		||||
        self.max_position_embeddings = max_position_embeddings
 | 
			
		||||
        assert topk_indices_buffer is None, "topk_indices_buffer is not \
 | 
			
		||||
        supported for DeepseekV2Attention"
 | 
			
		||||
 | 
			
		||||
        if self.q_lora_rank is not None:
 | 
			
		||||
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
 | 
			
		||||
@ -470,6 +443,390 @@ class DeepseekV2Attention(nn.Module):
 | 
			
		||||
        return output
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
 | 
			
		||||
                 cache_config: CacheConfig):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.kv_cache = [torch.tensor([])]
 | 
			
		||||
        self.head_dim = head_dim
 | 
			
		||||
        self.prefix = prefix
 | 
			
		||||
        self.cache_config = cache_config
 | 
			
		||||
        self.dtype = dtype
 | 
			
		||||
        compilation_config = get_current_vllm_config().compilation_config
 | 
			
		||||
        if prefix in compilation_config.static_forward_context:
 | 
			
		||||
            raise ValueError(f"Duplicate layer name: {prefix}")
 | 
			
		||||
        compilation_config.static_forward_context[prefix] = self
 | 
			
		||||
 | 
			
		||||
    def get_kv_cache_spec(self) -> KVCacheSpec:
 | 
			
		||||
        return MLAAttentionSpec(  # Only has one vector instead of K + V
 | 
			
		||||
            block_size=self.cache_config.block_size,
 | 
			
		||||
            num_kv_heads=1,
 | 
			
		||||
            head_size=self.head_dim,
 | 
			
		||||
            dtype=self.dtype,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def forward(self):
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def get_attn_backend(self) -> AttentionBackend:
 | 
			
		||||
        return DeepseekV32IndexerBackend
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.inference_mode()
 | 
			
		||||
def cp_gather_indexer_k_quant_cache(
 | 
			
		||||
    kv_cache,  # [num_blocks, block_size, head_dim + 1]
 | 
			
		||||
    dst_value,  # [cu_seq_lens[-1], head_dim]
 | 
			
		||||
    dst_scale,  # [cu_seq_lens[-1], 4]
 | 
			
		||||
    block_table,  # [batch_size, num_blocks]
 | 
			
		||||
    cu_seq_lens,  # [batch_size + 1, ]
 | 
			
		||||
    batch_size,
 | 
			
		||||
):
 | 
			
		||||
    num_blocks, block_size, _ = kv_cache.shape
 | 
			
		||||
    head_dim = dst_value.shape[-1]
 | 
			
		||||
    kv_cache = kv_cache.view(num_blocks, -1)
 | 
			
		||||
 | 
			
		||||
    expected_value = []
 | 
			
		||||
    expected_scale = []
 | 
			
		||||
    for b in range(batch_size):
 | 
			
		||||
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
 | 
			
		||||
        if s == 0:
 | 
			
		||||
            continue
 | 
			
		||||
        tot = cdiv(s, block_size)
 | 
			
		||||
        blocks = block_table[b, :tot]
 | 
			
		||||
 | 
			
		||||
        value = []
 | 
			
		||||
        scale = []
 | 
			
		||||
        full_block = torch.arange(tot - 1,
 | 
			
		||||
                                  device=kv_cache.device,
 | 
			
		||||
                                  dtype=torch.int32)
 | 
			
		||||
        non_remaining_value = kv_cache[blocks[full_block], :block_size *
 | 
			
		||||
                                       head_dim].view(-1, head_dim)
 | 
			
		||||
        non_remaining_scale = kv_cache[blocks[full_block],
 | 
			
		||||
                                       block_size * head_dim:].view(-1, 4)
 | 
			
		||||
 | 
			
		||||
        remaining = s - (tot - 1) * block_size
 | 
			
		||||
 | 
			
		||||
        value = torch.cat([
 | 
			
		||||
            non_remaining_value,
 | 
			
		||||
            kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
 | 
			
		||||
        ],
 | 
			
		||||
                          dim=0)
 | 
			
		||||
        scale = torch.cat([
 | 
			
		||||
            non_remaining_scale,
 | 
			
		||||
            kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
 | 
			
		||||
                     remaining * 4].view(-1, 4)
 | 
			
		||||
        ],
 | 
			
		||||
                          dim=0)
 | 
			
		||||
 | 
			
		||||
        expected_value.append(value)
 | 
			
		||||
        expected_scale.append(scale)
 | 
			
		||||
 | 
			
		||||
    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
 | 
			
		||||
    gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
 | 
			
		||||
    gather_value = gather_value.view(torch.float8_e4m3fn)
 | 
			
		||||
    gather_scale = gather_scale.view(torch.float32)
 | 
			
		||||
    dst_value.copy_(gather_value)
 | 
			
		||||
    dst_scale.copy_(gather_scale)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sparse_attn_indexer(
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    k_cache_prefix: str,
 | 
			
		||||
    kv_cache: torch.Tensor,
 | 
			
		||||
    q_fp8: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    weights: torch.Tensor,
 | 
			
		||||
    quant_block_size: int,
 | 
			
		||||
    scale_fmt: Optional[str],
 | 
			
		||||
    topk_tokens: int,
 | 
			
		||||
    head_dim: int,
 | 
			
		||||
    max_model_len: int,
 | 
			
		||||
    total_seq_lens: int,
 | 
			
		||||
    topk_indices_buffer: Optional[torch.Tensor],
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
    # Careful: attn_metadata will be None during a dummy run
 | 
			
		||||
    attn_metadata = get_forward_context().attn_metadata
 | 
			
		||||
    # assert isinstance(attn_metadata, dict)
 | 
			
		||||
    if not isinstance(attn_metadata, dict):
 | 
			
		||||
        return sparse_attn_indexer_fake(
 | 
			
		||||
            hidden_states,
 | 
			
		||||
            k_cache_prefix,
 | 
			
		||||
            kv_cache,
 | 
			
		||||
            q_fp8,
 | 
			
		||||
            k,
 | 
			
		||||
            weights,
 | 
			
		||||
            quant_block_size,
 | 
			
		||||
            scale_fmt,
 | 
			
		||||
            topk_tokens,
 | 
			
		||||
            head_dim,
 | 
			
		||||
            max_model_len,
 | 
			
		||||
            total_seq_lens,
 | 
			
		||||
            topk_indices_buffer,
 | 
			
		||||
        )
 | 
			
		||||
    attn_metadata = attn_metadata[k_cache_prefix]
 | 
			
		||||
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
 | 
			
		||||
    slot_mapping = attn_metadata.slot_mapping
 | 
			
		||||
    has_decode = attn_metadata.num_decodes > 0
 | 
			
		||||
    has_prefill = attn_metadata.num_prefills > 0
 | 
			
		||||
    num_decode_tokens = attn_metadata.num_decode_tokens
 | 
			
		||||
 | 
			
		||||
    ops.indexer_k_quant_and_cache(
 | 
			
		||||
        k,
 | 
			
		||||
        kv_cache,
 | 
			
		||||
        slot_mapping,
 | 
			
		||||
        quant_block_size,
 | 
			
		||||
        scale_fmt,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    topk_indices_buffer[:hidden_states.shape[0]] = -1
 | 
			
		||||
    if has_prefill:
 | 
			
		||||
        prefill_metadata = attn_metadata.prefill
 | 
			
		||||
        for chunk in prefill_metadata.chunks:
 | 
			
		||||
            k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
 | 
			
		||||
                                device=k.device,
 | 
			
		||||
                                dtype=torch.float8_e4m3fn)
 | 
			
		||||
            k_scale = torch.empty([chunk.total_seq_lens, 1],
 | 
			
		||||
                                  device=k.device,
 | 
			
		||||
                                  dtype=torch.float32)
 | 
			
		||||
            cp_gather_indexer_k_quant_cache(
 | 
			
		||||
                kv_cache,
 | 
			
		||||
                k_fp8,
 | 
			
		||||
                k_scale,
 | 
			
		||||
                chunk.block_table,
 | 
			
		||||
                chunk.cu_seq_lens,
 | 
			
		||||
                chunk.num_reqs,
 | 
			
		||||
            )
 | 
			
		||||
            logits = fp8_mqa_logits(
 | 
			
		||||
                q_fp8[chunk.token_start:chunk.token_end],
 | 
			
		||||
                (k_fp8, k_scale),
 | 
			
		||||
                weights[chunk.token_start:chunk.token_end],
 | 
			
		||||
                chunk.cu_seqlen_ks,
 | 
			
		||||
                chunk.cu_seqlen_ke,
 | 
			
		||||
            )
 | 
			
		||||
            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
 | 
			
		||||
                                       dim=-1)[1]
 | 
			
		||||
            topk_indices -= chunk.cu_seqlen_ks[:, None]
 | 
			
		||||
            mask_lo = topk_indices >= 0
 | 
			
		||||
            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
 | 
			
		||||
                                      chunk.cu_seqlen_ks)[:, None] < 0
 | 
			
		||||
            mask = torch.full_like(topk_indices,
 | 
			
		||||
                                   False,
 | 
			
		||||
                                   dtype=torch.bool,
 | 
			
		||||
                                   device=topk_indices.device)
 | 
			
		||||
            mask = mask_lo & mask_hi
 | 
			
		||||
            topk_indices = topk_indices.masked_fill(~mask, -1)
 | 
			
		||||
            topk_indices_buffer[
 | 
			
		||||
                chunk.token_start:chunk.token_end, :topk_indices.
 | 
			
		||||
                shape[-1]] = topk_indices.to(dtype=torch.int32)
 | 
			
		||||
 | 
			
		||||
    if has_decode:
 | 
			
		||||
        decode_metadata = attn_metadata.decode
 | 
			
		||||
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
 | 
			
		||||
        # we only have [num_block, block_size, head_dim],
 | 
			
		||||
        kv_cache = kv_cache.unsqueeze(-2)
 | 
			
		||||
        decode_lens = decode_metadata.decode_lens
 | 
			
		||||
        if decode_metadata.requires_padding:
 | 
			
		||||
            # pad in edge case where we have short chunked prefill length <
 | 
			
		||||
            # decode_threshold, since we only loosely split
 | 
			
		||||
            # prefill and decode by decode_threshold
 | 
			
		||||
            # (currently set to 1 + speculative tokens)
 | 
			
		||||
            padded_q_fp8_decode_tokens = pack_seq_triton(
 | 
			
		||||
                q_fp8[:num_decode_tokens], decode_lens)
 | 
			
		||||
        else:
 | 
			
		||||
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
 | 
			
		||||
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
 | 
			
		||||
        # TODO: move and optimize below logic with triton kernels
 | 
			
		||||
        batch_size = padded_q_fp8_decode_tokens.shape[0]
 | 
			
		||||
        next_n = padded_q_fp8_decode_tokens.shape[1]
 | 
			
		||||
        assert batch_size == decode_metadata.seq_lens.shape[0]
 | 
			
		||||
        num_padded_tokens = batch_size * next_n
 | 
			
		||||
        logits = fp8_paged_mqa_logits(
 | 
			
		||||
            padded_q_fp8_decode_tokens,
 | 
			
		||||
            kv_cache,
 | 
			
		||||
            weights[:num_padded_tokens],
 | 
			
		||||
            decode_metadata.seq_lens,
 | 
			
		||||
            decode_metadata.block_table,
 | 
			
		||||
            decode_metadata.schedule_metadata,
 | 
			
		||||
            max_model_len=max_model_len,
 | 
			
		||||
        )
 | 
			
		||||
        # padded query len
 | 
			
		||||
        current_device = padded_q_fp8_decode_tokens.device
 | 
			
		||||
        padded_num_tokens = batch_size * next_n
 | 
			
		||||
        positions = torch.arange(max_model_len,
 | 
			
		||||
                                 device=current_device).unsqueeze(0).expand(
 | 
			
		||||
                                     batch_size * next_n, -1)
 | 
			
		||||
        row_indices = torch.arange(padded_num_tokens,
 | 
			
		||||
                                   device=current_device) // next_n
 | 
			
		||||
        next_n_offset = torch.arange(
 | 
			
		||||
            padded_num_tokens,
 | 
			
		||||
            device=padded_q_fp8_decode_tokens.device) % next_n
 | 
			
		||||
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
 | 
			
		||||
                         next_n_offset).unsqueeze(1)
 | 
			
		||||
        # index_end_pos: [B * N, 1]
 | 
			
		||||
        mask = positions <= index_end_pos
 | 
			
		||||
        # mask: [B * N, L]
 | 
			
		||||
        logits = logits.masked_fill(~mask, float('-inf'))
 | 
			
		||||
        topk_indices = logits.topk(topk_tokens,
 | 
			
		||||
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
 | 
			
		||||
        # ensure we don't set indices for the top k
 | 
			
		||||
        # that are out of range (already masked);
 | 
			
		||||
        # this can happen if the context length is shorter than K
 | 
			
		||||
        topk_indices[topk_indices > index_end_pos] = -1
 | 
			
		||||
        if decode_metadata.requires_padding:
 | 
			
		||||
            # if padded, we need to unpack
 | 
			
		||||
            # the topk indices removing padded tokens
 | 
			
		||||
            topk_indices = unpack_seq_triton(
 | 
			
		||||
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
 | 
			
		||||
                decode_lens)
 | 
			
		||||
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
 | 
			
		||||
                            shape[-1]] = topk_indices.to(dtype=torch.int32)
 | 
			
		||||
 | 
			
		||||
    return topk_indices_buffer
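The decode-side masking above computes, for each padded draft-token row, the last key position it may attend to. A toy, shapes-only illustration of that index arithmetic with arbitrarily chosen sizes:

import torch

batch_size, next_n, max_model_len = 2, 2, 8
seq_lens = torch.tensor([5, 7])                       # per-request kv lengths
rows = torch.arange(batch_size * next_n) // next_n    # request id per padded row
offsets = torch.arange(batch_size * next_n) % next_n  # draft-token offset in the request
index_end_pos = (seq_lens[rows] - next_n + offsets).unsqueeze(1)
positions = torch.arange(max_model_len).unsqueeze(0).expand(batch_size * next_n, -1)
mask = positions <= index_end_pos                     # [B * N, L] visibility mask
assert mask.sum(dim=-1).tolist() == [4, 5, 6, 7]      # one extra key per draft token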
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sparse_attn_indexer_fake(
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
    k_cache_prefix: str,
 | 
			
		||||
    kv_cache: torch.Tensor,
 | 
			
		||||
    q_fp8: torch.Tensor,
 | 
			
		||||
    k: torch.Tensor,
 | 
			
		||||
    weights: torch.Tensor,
 | 
			
		||||
    quant_block_size: int,
 | 
			
		||||
    scale_fmt: Optional[str],
 | 
			
		||||
    topk_tokens: int,
 | 
			
		||||
    head_dim: int,
 | 
			
		||||
    max_model_len: int,
 | 
			
		||||
    total_seq_lens: int,
 | 
			
		||||
    topk_indices_buffer: Optional[torch.Tensor],
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    # profile run
 | 
			
		||||
    # NOTE(Chen): create the max possible flattened_kv so that
 | 
			
		||||
    # profile_run can get correct memory usage.
 | 
			
		||||
    _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
 | 
			
		||||
                                device=k.device,
 | 
			
		||||
                                dtype=torch.uint8)
 | 
			
		||||
    _k_fp8 = _flattened_kv[..., :head_dim].view(
 | 
			
		||||
        torch.float8_e4m3fn).contiguous()
 | 
			
		||||
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
 | 
			
		||||
    return topk_indices_buffer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
direct_register_custom_op(
 | 
			
		||||
    op_name="sparse_attn_indexer",
 | 
			
		||||
    op_func=sparse_attn_indexer,
 | 
			
		||||
    mutates_args=["topk_indices_buffer"],
 | 
			
		||||
    fake_impl=sparse_attn_indexer_fake,
 | 
			
		||||
    dispatch_key=current_platform.dispatch_key,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Indexer(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(self,
 | 
			
		||||
                 vllm_config: VllmConfig,
 | 
			
		||||
                 config: Union[DeepseekV2Config, DeepseekV3Config],
 | 
			
		||||
                 hidden_size: int,
 | 
			
		||||
                 q_lora_rank: int,
 | 
			
		||||
                 quant_config: Optional[QuantizationConfig],
 | 
			
		||||
                 cache_config: Optional[CacheConfig],
 | 
			
		||||
                 topk_indices_buffer: Optional[torch.Tensor],
 | 
			
		||||
                 prefix: str = ""):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.vllm_config = vllm_config
 | 
			
		||||
        self.config = config
 | 
			
		||||
        # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
 | 
			
		||||
        self.topk_tokens = config.index_topk
 | 
			
		||||
        self.n_head = config.index_n_heads  # 64
 | 
			
		||||
        self.head_dim = config.index_head_dim  # 128
 | 
			
		||||
        self.rope_dim = config.qk_rope_head_dim  # 64
 | 
			
		||||
        self.q_lora_rank = q_lora_rank  # 1536
 | 
			
		||||
        # no tensor parallel, just replicated
 | 
			
		||||
        self.wq_b = ReplicatedLinear(self.q_lora_rank,
 | 
			
		||||
                                     self.head_dim * self.n_head,
 | 
			
		||||
                                     bias=False,
 | 
			
		||||
                                     quant_config=quant_config,
 | 
			
		||||
                                     prefix=f"{prefix}.wq_b")
 | 
			
		||||
        self.wk = ReplicatedLinear(hidden_size,
 | 
			
		||||
                                   self.head_dim,
 | 
			
		||||
                                   bias=False,
 | 
			
		||||
                                   quant_config=quant_config,
 | 
			
		||||
                                   prefix=f"{prefix}.wk")
 | 
			
		||||
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
 | 
			
		||||
        self.weights_proj = ReplicatedLinear(hidden_size,
 | 
			
		||||
                                             self.n_head,
 | 
			
		||||
                                             quant_config=None,
 | 
			
		||||
                                             prefix=f"{prefix}.weights_proj")
 | 
			
		||||
        self.softmax_scale = self.head_dim**-0.5
 | 
			
		||||
 | 
			
		||||
        self.scale_fmt = "ue8m0"
 | 
			
		||||
        self.quant_block_size = 128  # TODO: get from config
 | 
			
		||||
        self.topk_indices_buffer = topk_indices_buffer
 | 
			
		||||
 | 
			
		||||
        # NOTE: (zyongye) we use fp8 naive cache,
 | 
			
		||||
        #       where we store value in fp8 and scale in fp32
 | 
			
		||||
        #       per self.quant_block_size elements
 | 
			
		||||
        self.k_cache = DeepseekV32IndexerCache(
 | 
			
		||||
            head_dim=self.head_dim +
 | 
			
		||||
            self.head_dim // self.quant_block_size * 4,
 | 
			
		||||
            dtype=torch.uint8,
 | 
			
		||||
            prefix=f"{prefix}.k_cache",
 | 
			
		||||
            cache_config=cache_config)
 | 
			
		||||
        self.max_model_len = vllm_config.model_config.max_model_len
 | 
			
		||||
        self.prefix = prefix
 | 
			
		||||
        from vllm.v1.attention.backends.mla.indexer import (
 | 
			
		||||
            get_max_prefill_buffer_size)
 | 
			
		||||
        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
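The `k_cache` above stores, per token, `head_dim` fp8 value bytes plus `head_dim // quant_block_size` float32 scales as raw uint8. A small, self-contained sketch of that value-plus-scale packing for a single token with one quantization group; the real CUDA cache layout is not reproduced here:

import torch

head_dim = 128
k = torch.randn(head_dim)

# One float32 scale for the whole 128-element group; values stored as fp8.
fp8_max = torch.finfo(torch.float8_e4m3fn).max            # 448.0 for e4m3fn
scale = k.abs().max() / fp8_max
# Clamp guards against float32 round-off pushing the max just past fp8 range.
k_fp8 = (k / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
packed = torch.cat([k_fp8.view(torch.uint8),
                    scale.to(torch.float32).reshape(1).view(torch.uint8)])
assert packed.numel() == head_dim + 4                     # head_dim bytes + one fp32 scale

k_restored = packed[:head_dim].view(torch.float8_e4m3fn).float() \
             * packed[head_dim:].view(torch.float32)
assert torch.allclose(k_restored, k, rtol=0.13, atol=1e-4)  # within fp8 quantization error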
 | 
			
		||||
 | 
			
		||||
    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
 | 
			
		||||
                rotary_emb) -> torch.Tensor:
 | 
			
		||||
        q, _ = self.wq_b(qr)
 | 
			
		||||
        q = q.view(-1, self.n_head, self.head_dim)
 | 
			
		||||
        q_pe, q_nope = torch.split(
 | 
			
		||||
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
 | 
			
		||||
 | 
			
		||||
        k, _ = self.wk(hidden_states)
 | 
			
		||||
        k = self.k_norm(k)
 | 
			
		||||
        k_pe, k_nope = torch.split(
 | 
			
		||||
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
 | 
			
		||||
 | 
			
		||||
        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
 | 
			
		||||
        q = torch.cat([q_pe, q_nope], dim=-1)
 | 
			
		||||
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
 | 
			
		||||
 | 
			
		||||
        # we only quant q here since k quant is fused with cache insertion
 | 
			
		||||
        q = q.view(-1, self.head_dim)
 | 
			
		||||
        q_fp8, q_scale = per_token_group_quant_fp8(q,
 | 
			
		||||
                                                   self.quant_block_size,
 | 
			
		||||
                                                   column_major_scales=False,
 | 
			
		||||
                                                   use_ue8m0=self.scale_fmt
 | 
			
		||||
                                                   is not None)
 | 
			
		||||
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
 | 
			
		||||
        q_scale = q_scale.view(-1, self.n_head, 1)
 | 
			
		||||
 | 
			
		||||
        weights, _ = self.weights_proj(hidden_states)
 | 
			
		||||
        weights = weights.unsqueeze(
 | 
			
		||||
            -1) * q_scale * self.softmax_scale * self.n_head**-0.5
 | 
			
		||||
        weights = weights.squeeze(-1)
 | 
			
		||||
 | 
			
		||||
        return torch.ops.vllm.sparse_attn_indexer(
 | 
			
		||||
            hidden_states,
 | 
			
		||||
            self.k_cache.prefix,
 | 
			
		||||
            self.k_cache.kv_cache[0],
 | 
			
		||||
            q_fp8,
 | 
			
		||||
            k,
 | 
			
		||||
            weights,
 | 
			
		||||
            self.quant_block_size,
 | 
			
		||||
            self.scale_fmt,
 | 
			
		||||
            self.topk_tokens,
 | 
			
		||||
            self.head_dim,
 | 
			
		||||
            self.max_model_len,
 | 
			
		||||
            self.max_total_seq_len,
 | 
			
		||||
            self.topk_indices_buffer,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DeepseekV2MLAAttention(nn.Module):
 | 
			
		||||
    """
 | 
			
		||||
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
 | 
			
		||||
@ -481,6 +838,7 @@ class DeepseekV2MLAAttention(nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        vllm_config: VllmConfig,
 | 
			
		||||
        config: Union[DeepseekV2Config, DeepseekV3Config],
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        num_heads: int,
 | 
			
		||||
@ -495,6 +853,7 @@ class DeepseekV2MLAAttention(nn.Module):
 | 
			
		||||
        cache_config: Optional[CacheConfig] = None,
 | 
			
		||||
        quant_config: Optional[QuantizationConfig] = None,
 | 
			
		||||
        prefix: str = "",
 | 
			
		||||
        topk_indices_buffer: Optional[torch.Tensor] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.hidden_size = hidden_size
 | 
			
		||||
@ -575,6 +934,15 @@ class DeepseekV2MLAAttention(nn.Module):
 | 
			
		||||
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
 | 
			
		||||
            self.scaling = self.scaling * mscale * mscale

        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer = Indexer(vllm_config, config, hidden_size,
                                   q_lora_rank, quant_config, cache_config,
                                   topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer = None

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
@ -588,7 +956,11 @@ class DeepseekV2MLAAttention(nn.Module):
            if self.q_lora_rank is not None else None,
            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
            q_proj=self.q_proj if self.q_lora_rank is None else None,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
@ -614,7 +986,10 @@ class DeepseekV2MLAAttention(nn.Module):

class DeepseekV2DecoderLayer(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str,
                 topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
@ -637,6 +1012,7 @@ class DeepseekV2DecoderLayer(nn.Module):
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@ -652,6 +1028,7 @@ class DeepseekV2DecoderLayer(nn.Module):
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        if (config.n_routed_experts is not None
@ -735,6 +1112,16 @@ class DeepseekV2Model(nn.Module):
        self.config = config

        self.vocab_size = config.vocab_size
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device="cuda")
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
@ -747,7 +1134,8 @@ class DeepseekV2Model(nn.Module):

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
                                                  topk_indices_buffer),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
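Note on the DeepseekV2 hunks above: the DeepSeek-V3.2 sparse-attention path allocates a single `topk_indices_buffer` of shape `[max_num_batched_tokens, index_topk]` in the model and threads it through every decoder layer and indexer. A minimal standalone sketch of that buffer sizing (illustrative helper name, not part of vLLM):

```python
import torch

def allocate_topk_indices_buffer(max_num_batched_tokens: int,
                                 index_topk: int,
                                 device: str = "cpu") -> torch.Tensor:
    # One row per batched token, one column per selected KV index;
    # int32 matches the dtype used in the diff above.
    return torch.empty(max_num_batched_tokens, index_topk,
                       dtype=torch.int32, device=device)

# Example: an 8192-token scheduler budget with index_topk=2048 costs
# 8192 * 2048 * 4 bytes = 64 MiB for the shared buffer.
buf = allocate_topk_indices_buffer(8192, 2048)
print(buf.shape)  # torch.Size([8192, 2048])
```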
@ -29,10 +29,9 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        vllm_config: VllmConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.mtp_emb_norm = RMSNorm(config.hidden_size,
                                    eps=config.rms_norm_eps)
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
        self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
                                         config.hidden_size,
                                         bias=False)
        self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
                                           prefix)
        self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)

    def forward(
        self,
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
        self.layers = torch.nn.ModuleDict({
            str(idx):
            ErnieMultiTokenPredictorLayer(
                config,
                vllm_config,
                f"{prefix}.layers.{idx}",
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
            )
            for idx in range(self.mtp_start_layer_idx,
                             self.mtp_start_layer_idx + self.num_mtp_layers)

@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):

class Glm4DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Glm4Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 config: Optional[Glm4Config] = None) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)

@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        vllm_config: VllmConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False,
                                has_bias=True,
                                activation="swigluoai")
                                activation="swigluoai",
                                is_sequence_parallel=self.is_sequence_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.shape[0]
        if self.is_sequence_parallel:
            x = sequence_parallel_chunk(x)

        g = self.router(x)
        x = self.experts(hidden_states=x, router_logits=g)

        if self.is_sequence_parallel:
            x = tensor_model_parallel_all_gather(x.contiguous(), 0)
            x = x[:num_tokens]
        return x
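The `MLPBlock.forward` change above is the sequence-parallel MoE pattern that recurs throughout this release: chunk the token dimension before routing, run the experts on the local shard, then all-gather along dim 0 and trim back to the original token count. A single-process sketch of that flow, with stand-ins for the distributed helpers (assumed names and shapes, not the vLLM implementations):

```python
import torch

TP_RANK, TP_SIZE = 0, 2  # stand-ins for the real process-group queries

def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    # Pad the token dim to a multiple of TP_SIZE, keep this rank's shard.
    pad = (-x.shape[0]) % TP_SIZE
    if pad:
        x = torch.cat([x, x.new_zeros(pad, x.shape[1])])
    return x.chunk(TP_SIZE, dim=0)[TP_RANK]

def tensor_model_parallel_all_gather(x: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Single-process stand-in: a real all-gather concatenates every rank's shard.
    return torch.cat([x] * TP_SIZE, dim=dim)

def moe_forward(x: torch.Tensor, experts) -> torch.Tensor:
    num_tokens = x.shape[0]
    x = sequence_parallel_chunk(x)              # each rank routes only its shard
    x = experts(x)                              # expert computation on the shard
    x = tensor_model_parallel_all_gather(x, 0)  # reassemble the full token dim
    return x[:num_tokens]                       # drop padding added by chunking

out = moe_forward(torch.randn(5, 8), experts=torch.nn.Identity())
print(out.shape)  # torch.Size([5, 8])
```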
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):

    def __init__(
        self,
        config: GptOssConfig,
        cache_config: CacheConfig,
        quant_config: QuantizationConfig,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config

        self.layer_idx = extract_layer_index(prefix)
        self.attn = OAIAttention(config,
                                 prefix=f"{prefix}.attn",
                                 cache_config=cache_config)
        self.mlp = MLPBlock(config,
        self.mlp = MLPBlock(vllm_config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
    ):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.config.hidden_size = self.config.hidden_size
        self.embedding = VocabParallelEmbedding(
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                self.config,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",

@ -29,12 +29,13 @@ from typing import Any, Optional

import torch
from torch import nn
from transformers.models.granitemoe import GraniteMoeConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 is_sequence_parallel=False,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_sequence_parallel = is_sequence_parallel

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(hidden_size,
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                prefix=f"{prefix}.experts")
                                prefix=f"{prefix}.experts",
                                is_sequence_parallel=self.is_sequence_parallel)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            num_tokens = orig_shape[0]
            final_hidden_states = final_hidden_states[:num_tokens]

        return final_hidden_states.view(orig_shape)


@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GraniteMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
            prefix=f"{prefix}.block_sparse_moe")

        self.input_layernorm = RMSNorm(config.hidden_size,
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
        bias: bool = False,
        prefix: str = "",
        reduce_results: bool = True,
        disable_tp: bool = False,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
            output_sizes=[intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            disable_tp=disable_tp,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
            bias=bias,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=disable_tp,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):

class LlamaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 config: Optional[LlamaConfig] = None) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(config=config,
                                      cache_config=cache_config,
                                      quant_config=quant_config,
                                      prefix=prefix),
            lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
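The Llama hunks above illustrate the constructor convention adopted across these model files: decoder layers take a `VllmConfig` plus a `prefix` and pull `hf_config`, `cache_config`, and `quant_config` out of it themselves, with an optional `config` override for wrappers such as the EAGLE and MTP heads. A rough sketch of that shape with dummy config classes (all names below are illustrative, not the vLLM types):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DummyHFConfig:             # stand-in for a HF text config
    hidden_size: int = 64

@dataclass
class DummyModelConfig:
    hf_config: DummyHFConfig = field(default_factory=DummyHFConfig)

@dataclass
class DummyVllmConfig:           # stand-in for vllm.config.VllmConfig
    model_config: DummyModelConfig = field(default_factory=DummyModelConfig)
    cache_config: Optional[object] = None
    quant_config: Optional[object] = None

class MyDecoderLayer:
    # New-style constructor: one vllm_config handle plus an optional config
    # override, so speculative/MTP wrappers can substitute their own config.
    def __init__(self, vllm_config: DummyVllmConfig, prefix: str = "",
                 config: Optional[DummyHFConfig] = None) -> None:
        config = config or vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.hidden_size = config.hidden_size
        self.prefix = prefix

layer = MyDecoderLayer(DummyVllmConfig(), prefix="model.layers.0")
print(layer.hidden_size)  # 64
```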
@ -28,7 +28,8 @@ from vllm.attention import Attention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk

from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
        router_scores = torch.sigmoid(router_scores.float())
        return (router_scores, router_indices.to(torch.int32))

    def __init__(self,
                 config: Llama4TextConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()
        self.top_k = config.num_experts_per_tok
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        intermediate_size_moe = config.intermediate_size
        self.router = ReplicatedLinear(config.hidden_size,
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
            bias=False,
            prefix=f"{prefix}.shared_expert",
            reduce_results=False,
            disable_tp=self.is_sequence_parallel,
        )

        self.experts = SharedFusedMoE(
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
            renormalize=False,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states):
        num_tokens = hidden_states.shape[0]
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        router_logits, _ = self.router(hidden_states)

        shared_out, routed_out = self.experts(
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
        )
        experts_out = routed_out + shared_out

        if self.tp_size > 1:
        if self.is_sequence_parallel:
            experts_out = tensor_model_parallel_all_gather(experts_out, 0)
            experts_out = experts_out[:num_tokens]
        elif self.tp_size > 1:
            experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                experts_out)

@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):

class Llama4DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Llama4TextConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 config: Optional[Llama4TextConfig] = None) -> None:
        super().__init__()

        config = config or vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.layer_idx = extract_layer_index(prefix)
        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
        self.hidden_size = config.hidden_size
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
            self.layer_idx + 1) % config.interleave_moe_layer_step == 0
        if is_moe_layer:
            self.feed_forward = Llama4MoE(
                config=config,
                quant_config=quant_config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.feed_forward",
            )
        else:

@ -68,9 +68,9 @@ class LlamaModel(nn.Module):

        self.layers = nn.ModuleList([
            Llama4DecoderLayer(
                self.config,
                quant_config=quant_config,
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                config=self.config,
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,

@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):

    def __init__(
        self,
        config: LlamaConfig,
        vllm_config: VllmConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
        config: Optional[LlamaConfig] = None,
    ) -> None:
        super().__init__(config, prefix=prefix)
        super().__init__(vllm_config, prefix=prefix, config=config)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):

        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                vllm_config,
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                config=self.config,
            ) for i in range(self.config.num_hidden_layers)
        ])
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,

@ -9,13 +9,11 @@ import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -29,17 +27,14 @@ logger = init_logger(__name__)

class LlamaDecoderLayer(LlamaDecoderLayer):

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         prefix=prefix)
    def __init__(self,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 config: Optional[LlamaConfig] = None) -> None:
        super().__init__(vllm_config, prefix=prefix, config=config)

        config = config or vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        # override qkv
        self.self_attn.qkv_proj = QKVParallelLinear(
@ -127,9 +122,9 @@ class LlamaModel(nn.Module):

        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                config=self.config,
                cache_config=current_vllm_config.cache_config,
                current_vllm_config,
                prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
                config=self.config,
            )
        ])
        if hasattr(self.config, "target_hidden_size"):

@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module):

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: FlashConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
@ -329,6 +330,7 @@ class FlashDecoderLayer(nn.Module):
        # Dual attention structure
        self.self_attn = nn.ModuleList([
            DeepseekV2MLAAttention(
                vllm_config=vllm_config,
                config=config,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
@ -454,6 +456,7 @@ class FlashModel(nn.Module):
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: FlashDecoderLayer(
                vllm_config,
                config,
                cache_config=cache_config,
                quant_config=quant_config,

@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj",
                                      disable_tp=use_data_parallel)

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
            )
        self.attn_backend = attn_backend
        self.use_upstream_fa = use_upstream_fa
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
        }
@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel)
            use_data_parallel=use_data_parallel,
            attn_backend=attn_backend,
            use_upstream_fa=use_upstream_fa)
        self.mlp = Qwen2_5_VisionMLP(dim,
                                     mlp_hidden_dim,
                                     act_fn=act_fn,
@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        use_upstream_fa = False
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen2.5-VL does not support {self.attn_backend} backend now."
            )

        self.blocks = nn.ModuleList([
            Qwen2_5_VisionBlock(dim=self.hidden_size,
                                num_heads=self.num_heads,
                                mlp_hidden_dim=vision_config.intermediate_size,
                                act_fn=get_act_and_mul_fn(
                                    vision_config.hidden_act),
                                norm_layer=norm_layer,
                                quant_config=quant_config,
                                prefix=f"{prefix}.blocks.{layer_idx}",
                                use_data_parallel=use_data_parallel)
            for layer_idx in range(depth)
            Qwen2_5_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=get_act_and_mul_fn(vision_config.hidden_act),
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend=self.attn_backend,
                use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
        ])
        self.merger = Qwen2_5_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
            prefix=f"{prefix}.merger",
            use_data_parallel=use_data_parallel,
        )
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
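The Qwen2.5-VL hunks above move ViT attention-backend detection out of each attention module and into `Qwen2_5_VisionTransformer`, which then injects `attn_backend` and `use_upstream_fa` into every block. A toy sketch of the select-once, inject-everywhere shape (the enum and detection helper below are stand-ins, not vLLM APIs):

```python
from enum import Enum, auto

class Backend(Enum):             # stand-in for vllm's _Backend enum
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()

def detect_backend(head_size: int) -> Backend:
    # Illustrative only: the real code inspects platform and dtype support.
    return Backend.FLASH_ATTN if head_size % 8 == 0 else Backend.TORCH_SDPA

class VisionBlock:
    def __init__(self, attn_backend: Backend) -> None:
        self.attn_backend = attn_backend   # injected, not re-detected per block

class VisionTransformer:
    def __init__(self, depth: int, head_size: int) -> None:
        backend = detect_backend(head_size)          # decided exactly once
        self.blocks = [VisionBlock(backend) for _ in range(depth)]

vit = VisionTransformer(depth=4, head_size=64)
print({block.attn_backend for block in vit.blocks})  # one shared backend
```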
@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
logger = init_logger(__name__)

# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_FRAMES_PER_VIDEO = 14

# === Vision Inputs === #

@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
        _, num_image_tokens = self._get_vision_info(
            image_width=image_width,
            image_height=image_height,
            num_frames=1,
            image_processor=image_processor,
        )
        return num_image_tokens
@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size
@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
            image_processor=None,
        )

    def _get_max_video_frames(self, max_tokens: int) -> int:
    def _get_max_video_frames(self,
                              max_tokens: int,
                              start_num_frames: int = 1) -> int:
        target_width, target_height = self.get_image_size_with_most_features()

        num_frames = 0
        num_frames = start_num_frames

        while True:
            next_num_frames = num_frames + 1
@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
    ) -> int:
        max_videos = mm_counts.get("video", 0)

        max_total_frames = self._get_max_video_frames(seq_len)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)
                                   max_frames_per_video)

        return max(max_frames_per_video, 1)
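The Qwen2-VL profiling hunks above lower `_MAX_FRAMES_PER_VIDEO` to 14 and make the per-video frame budget a parameter: total frames that fit in the token budget, split across the videos, clamped to the cap. A small arithmetic sketch under an assumed per-frame token cost (the cost value below is made up for illustration):

```python
def frames_per_video(seq_len: int, num_videos: int,
                     tokens_per_frame: int, cap: int = 14) -> int:
    # Total frames whose tokens fit in the sequence budget.
    max_total_frames = seq_len // tokens_per_frame
    # Split across videos, clamp to the per-video cap, keep at least one frame.
    return max(min(max_total_frames // max(num_videos, 1), cap), 1)

# 32k-token budget, 2 videos, an assumed ~1200 tokens per frame -> 13 frames each.
print(frames_per_video(32_768, 2, 1_200))  # 13
```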
@ -29,13 +29,13 @@ from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import Qwen3MoeConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_world_size)
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.sequence import IntermediateTensors

from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):

    def __init__(
        self,
        config: Qwen3MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        vllm_config: VllmConfig,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                enable_eplb=self.enable_eplb,
                                num_redundant_experts=self.n_redundant_experts)
                                num_redundant_experts=self.n_redundant_experts,
                                is_sequence_parallel=self.is_sequence_parallel)

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
        assert hidden_states.dim(
        ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        is_input_1d = hidden_states.dim() == 1
        hidden_dim = hidden_states.shape[-1]
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else \
            final_hidden_states
@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):

class Qwen3MoeDecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen3MoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
        if (layer_idx not in mlp_only_layers) and (
                config.num_experts > 0 and
            (layer_idx + 1) % config.decoder_sparse_step == 0):
            self.mlp = Qwen3MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.mlp",
                                              enable_eplb=enable_eplb)
            self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
                                              prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                   intermediate_size=config.intermediate_size,
@ -361,11 +372,9 @@ class Qwen3MoeModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.get_text_config()
        cache_config = vllm_config.cache_config
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        enable_eplb = parallel_config.enable_eplb
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
            prefix=f"{prefix}.embed_tokens")
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen3MoeDecoderLayer(config=config,
                                                cache_config=cache_config,
                                                quant_config=quant_config,
                                                prefix=prefix,
                                                enable_eplb=enable_eplb),
            lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -580,7 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
 | 
			
		||||
                         VllmConfig, get_current_vllm_config)
 | 
			
		||||
from vllm.distributed import (divide, get_ep_group, get_pp_group,
 | 
			
		||||
                              get_tensor_model_parallel_rank,
 | 
			
		||||
                              get_tensor_model_parallel_world_size)
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
@@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]


class Qwen3NextSparseMoeBlock(nn.Module):

    def __init__(
        self,
        config: Qwen3NextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        parallel_config = vllm_config.parallel_config
        quant_config = vllm_config.quant_config

        self.tp_size = get_tensor_model_parallel_world_size()

        self.ep_group = get_ep_group().device_group
@@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
        self.ep_size = self.ep_group.size()
        self.n_routed_experts = config.num_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
@@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb
        self.enable_eplb = parallel_config.enable_eplb

        self.n_logical_experts = self.n_routed_experts
        self.n_redundant_experts = eplb_config.num_redundant_experts
@@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                enable_eplb=self.enable_eplb,
                                num_redundant_experts=self.n_redundant_experts)
                                num_redundant_experts=self.n_redundant_experts,
                                is_sequence_parallel=self.is_sequence_parallel)

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
@@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
@@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
                final_hidden_states)

@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen3NextConfig,
        vllm_config: VllmConfig,
        layer_type: str,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.config = config

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        speculative_config = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)
@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
                config.num_experts > 0 and
            (self.layer_idx + 1) % config.decoder_sparse_step == 0):
            self.mlp = Qwen3NextSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = Qwen3NextMLP(
@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
                torch.zeros(
                    1,
                    1,
                    self.config.hidden_size,
                    config.hidden_size,
                    dtype=config.torch_dtype,
                ), )
            self.ffn_layer_scale = torch.nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    self.config.hidden_size,
                    config.hidden_size,
                    dtype=config.torch_dtype,
                ), )

@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        lora_config = vllm_config.lora_config
        speculative_config = vllm_config.speculative_config
        enable_eplb = parallel_config.enable_eplb
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):

        def get_layer(prefix: str):
            return Qwen3NextDecoderLayer(
                config,
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                speculative_config=speculative_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(

@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
        super().__init__()

        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config: Qwen3NextConfig = model_config.hf_config
@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):

        self.layers = torch.nn.ModuleList(
            Qwen3NextDecoderLayer(
                config,
                vllm_config,
                layer_type="full_attention",
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f'{prefix}.layers.{idx}',
            ) for idx in range(self.num_mtp_layers))


@@ -33,11 +33,14 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    smart_resize as image_smart_resize)
from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                          Qwen3VLVideoProcessor)
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
    Qwen3VLConfig, Qwen3VLVisionConfig)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
    smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata

from vllm.attention.layer import check_upstream_fa_availability
@@ -84,6 +87,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

logger = init_logger(__name__)

# Official recommended max pixels is 24576 * 32 * 32
_MAX_FRAMES_PER_VIDEO = 24576


class Qwen3_VisionPatchEmbed(nn.Module):

@@ -158,6 +164,8 @@ class Qwen3_VisionBlock(nn.Module):
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:
@@ -170,7 +178,9 @@ class Qwen3_VisionBlock(nn.Module):
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel)
            use_data_parallel=use_data_parallel,
            attn_backend=attn_backend,
            use_upstream_fa=use_upstream_fa)
        self.mlp = Qwen3_VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_fn=act_fn,
@@ -287,19 +297,6 @@ class Qwen3_VisionTransformer(nn.Module):
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen3_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel)
            for layer_idx in range(vision_config.depth)
        ])

        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
@@ -325,10 +322,34 @@ class Qwen3_VisionTransformer(nn.Module):

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(
                torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen3-VL does not support {self.attn_backend} backend now.")

        self.blocks = nn.ModuleList([
            Qwen3_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend=self.attn_backend,
                use_upstream_fa=use_upstream_fa)
            for layer_idx in range(vision_config.depth)
        ])

    @property
    def dtype(self) -> torch.dtype:
@@ -569,11 +590,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
        image_height: int,
        num_frames: int = 2,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessorFast],
        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
                                        Qwen3VLVideoProcessor]],
    ) -> tuple[ImageSize, int]:
        if image_processor is None:
        if image_processor is None and num_frames > 1:
            image_processor = self.get_video_processor()
        elif image_processor is None:
            image_processor = self.get_image_processor()

        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
@@ -581,12 +607,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            if is_video:
                smart_resize = video_smart_resize
                extra_kwargs = {
                    "num_frames": num_frames,
                    "temporal_factor": temporal_patch_size
                }
            else:
                smart_resize = image_smart_resize
                extra_kwargs = {}
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
                **extra_kwargs,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
@@ -605,6 +641,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self,
                              max_tokens: int,
                              start_num_frames: int = 2) -> int:
        return super()._get_max_video_frames(max_tokens,
                                             start_num_frames=start_num_frames)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )

        # NOTE: By default in Qwen3-VL, one video token is converted to
        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
        formatted_video_soft_tokens = video_soft_tokens * 12.5
        return int(formatted_video_soft_tokens)

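The 12.5 multiplier above simply folds the text wrapper into the raw vision-token count: per the NOTE, each video token expands to a "<{timestamp} seconds>" string (about 9.5 text tokens on average) plus the vision_start, video, and vision_end tokens. A minimal sketch of that budget arithmetic, with the 9.5-token average taken from the comment above and the soft-token count chosen arbitrarily:

# Illustrative only: reproduces the 12.5x budget used in get_max_video_tokens.
AVG_TIMESTAMP_TOKENS = 9.5   # "<{timestamp} seconds>" text, per the NOTE above
SPECIAL_TOKENS = 3           # vision_start + video + vision_end
TOKENS_PER_VIDEO_TOKEN = AVG_TIMESTAMP_TOKENS + SPECIAL_TOKENS  # = 12.5

video_soft_tokens = 1024     # hypothetical result of get_num_video_tokens()
max_video_tokens = int(video_soft_tokens * TOKENS_PER_VIDEO_TOKEN)
print(max_video_tokens)      # 12800
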
    def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                              video_fps: float, merge_size: int):
        if not isinstance(indices, list):
@@ -674,6 +743,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)
        target_video_size, _ = self.info._get_vision_info(
            image_width=target_width,
            image_height=target_height,
            num_frames=target_num_frames,
            image_processor=self.info.get_video_processor(),
        )
        return {
            "image":
            self._get_dummy_images(width=target_width,
@@ -681,8 +756,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                width=target_video_size.width,
                height=target_video_size.height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),
@@ -1051,14 +1126,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )
        if not multimodal_config.get_limit_per_prompt("image") and \
            not multimodal_config.get_limit_per_prompt("video"):
            self.visual = None
        else:
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )

        self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
                                                  prefix=maybe_prefix(
@@ -1074,11 +1152,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            config.vision_config.deepstack_visual_indexes
        ) if self.use_deepstack else 0
        # register buffer for deepstack
        self.deepstack_input_embeds = [
            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size)
            for _ in range(self.deepstack_num_level)
        ] if self.use_deepstack else None
        if self.use_deepstack and self.visual is not None:
            self.deepstack_input_embeds = [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.hidden_size)
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

@@ -1513,7 +1595,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)

        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:

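Taken together, these hunks make the vision tower optional: when both the image and the video limits are zero, Qwen3-VL skips building Qwen3_VisionTransformer, skips the deepstack buffers, and skips "visual." weights at load time. A minimal usage sketch, assuming the usual limit_mm_per_prompt engine argument; the checkpoint id below is a placeholder, not a real model name:

# Sketch only: run a Qwen3-VL checkpoint text-only so the vision tower is never built.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-VL-placeholder",             # hypothetical checkpoint id
    limit_mm_per_prompt={"image": 0, "video": 0},  # both limits zero -> self.visual is None
)
out = llm.generate("Describe tensor parallelism in one sentence.")
print(out[0].outputs[0].text)
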
@@ -212,6 +212,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1,
                                                                -2)  # no bias
@@ -230,8 +232,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                                name_mapped, params_dict, loaded_weight,
                                shard_id, num_experts)
                    else:
                        if is_pp_missing_parameter(name_mapped, self):
                            continue
                        # Skip loading extra parameters for GPTQ/modelopt models
                        if name_mapped.endswith(
                                ignore_suffixes
@@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )
        if not multimodal_config.get_limit_per_prompt("image") and \
            not multimodal_config.get_limit_per_prompt("video"):
            self.visual = None
        else:
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )

        self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
                                                     prefix=maybe_prefix(
@@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
            config.vision_config.deepstack_visual_indexes
        ) if self.use_deepstack else 0
        # register buffer for deepstack
        self.deepstack_input_embeds = [
            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size)
            for _ in range(self.deepstack_num_level)
        ] if self.use_deepstack else None
        if self.use_deepstack and self.visual is not None:
            self.deepstack_input_embeds = [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.hidden_size)
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

@@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = {
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),

@@ -13,11 +13,14 @@ from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                        is_uva_available)

logger = init_logger(__name__)
@@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
        return hf_config.hidden_size
    text_config = hf_config.get_text_config()
    return text_config.hidden_size


# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.vllm.sequence_parallel_chunk_impl(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        y = nn.functional.pad(x, (0, 0, 0, pad_len))
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    seq_len = cdiv(x.size(0), tp_size)
    shape = list(x.shape)
    shape[0] = seq_len
    out = torch.empty(shape, dtype=x.dtype, device=x.device)
    return out


direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)

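The helper above pads the token dimension up to a multiple of the tensor-parallel world size before taking this rank's slice, so the later tensor_model_parallel_all_gather (used in the MoE forward earlier in this diff) can reassemble the full sequence and trim back to the original num_tokens. A minimal single-process sketch of that pad/chunk/gather round trip, with tp_size and the tensor shape chosen arbitrarily:

# Standalone sketch (no distributed setup): emulate chunk -> all_gather -> trim.
import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 4, 10, 8        # arbitrary example values
x = torch.randn(num_tokens, hidden)

pad_len = (-num_tokens) % tp_size             # 2, so 12 rows split evenly
y = F.pad(x, (0, 0, 0, pad_len))
chunk = y.shape[0] // tp_size
chunks = [torch.narrow(y, 0, r * chunk, chunk) for r in range(tp_size)]  # per-rank slices

gathered = torch.cat(chunks, dim=0)           # what all_gather along dim 0 restores
assert torch.equal(gathered[:num_tokens], x)  # trimmed back to the original tokens
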
@@ -50,6 +50,7 @@ class MediaConnector:
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
@@ -82,6 +83,9 @@ class MediaConnector:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_
        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
@@ -115,6 +119,14 @@ class MediaConnector:

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
        if self.allowed_media_domains and url_spec.hostname not in \
            self.allowed_media_domains:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}")

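_assert_url_in_allowed_media_domains is a plain hostname allowlist: an empty list means "allow everything", otherwise the parsed hostname must match one of the configured domains exactly. A small standalone sketch of the same check using only the standard library; the domain list and URLs are made up, and this is not the MediaConnector API itself:

# Standalone sketch of the allowlist check.
from urllib.parse import urlparse

allowed_media_domains = ["example.com", "cdn.example.com"]   # hypothetical config

def check(url: str) -> None:
    host = urlparse(url).hostname
    if allowed_media_domains and host not in allowed_media_domains:
        raise ValueError(f"URL domain {host!r} is not in {allowed_media_domains}")

check("https://cdn.example.com/cat.png")    # passes
check("https://evil.example.org/cat.png")   # raises ValueError
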
    def load_from_url(
        self,
        url: str,
@@ -125,8 +137,14 @@ class MediaConnector:
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

@@ -150,8 +168,14 @@ class MediaConnector:
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool,
                                          media_io.load_bytes, data)
            return await future

@@ -93,11 +93,14 @@ class CpuPlatform(Platform):
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
                             has_sink: bool, use_sparse: bool) -> str:
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on CPU.")
        logger.info("Using Torch SDPA backend.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")

@@ -129,6 +129,8 @@ class CudaPlatformBase(Platform):
        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            use_sparse = hasattr(vllm_config.model_config.hf_config,
                                 "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the
@@ -175,6 +177,12 @@ class CudaPlatformBase(Platform):
                    "Forcing kv cache block size to 64 for FlashInferMLA "
                    "backend.")

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse "
                    "backend.")
        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

@@ -205,6 +213,12 @@ class CudaPlatformBase(Platform):
    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS

@@ -225,7 +239,7 @@ class CudaPlatformBase(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
                             has_sink, use_sparse) -> str:
        if use_mla:
            if not use_v1:
                raise RuntimeError(
@@ -235,6 +249,11 @@ class CudaPlatformBase(Platform):
            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                        "FlashMLASparseBackend")

            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size == 128)

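Taken together, the CUDA hunks above wire sparse attention through backend selection: use_sparse is inferred from the presence of an index_topk field on the HF config (which appears to target the DeepseekV32 entries registered elsewhere in this diff), the KV-cache block size is forced to 64, and MLA requests are routed to the FlashMLASparse backend before the dense FlashMLA/CutlassMLA choices are considered. A compressed sketch of that decision order, in plain Python with a stand-in config object rather than vLLM's real classes:

# Sketch of the selection order only; the real logic lives in CudaPlatformBase.
from types import SimpleNamespace

def pick_mla_backend(hf_config, block_size):
    use_sparse = hasattr(hf_config, "index_topk")
    if use_sparse:
        block_size = 64   # mirrors the cache_config override above
        return "mla.flashmla_sparse.FlashMLASparseBackend", block_size
    return "dense MLA backend (FlashMLA / CutlassMLA / FlashInferMLA)", block_size

sparse_cfg = SimpleNamespace(index_topk=2048)   # hypothetical sparse-attention config
print(pick_mla_backend(sparse_cfg, block_size=128))
print(pick_mla_backend(SimpleNamespace(), block_size=128))
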
@@ -194,7 +194,7 @@ class Platform:
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
                             has_sink: bool, use_sparse: bool) -> str:
        """Get the attention backend class of a device."""
        return ""


@@ -195,7 +195,10 @@ class RocmPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
                             has_sink, use_sparse) -> str:
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on ROCm.")
        if use_mla:
            if not use_v1:
                raise RuntimeError(

@@ -49,7 +49,10 @@ class TpuPlatform(Platform):
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink) -> str:
                             has_sink, use_sparse) -> str:
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)


Some files were not shown because too many files have changed in this diff.