mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
34 Commits
debug-logs
...
v0.11.0
Author | SHA1 | Date | |
---|---|---|---|
b8b302cde4 | |||
f71952c1c4 | |||
d1007767c5 | |||
c75c2e70d6 | |||
9d9a2b77f1 | |||
6040e0b6c0 | |||
05bf0c52a1 | |||
c536881a7c | |||
ebce361c07 | |||
e4beabd2c8 | |||
febb688356 | |||
a1825fe645 | |||
bab9231bf1 | |||
c214d699fd | |||
c3dfb0f6dd | |||
83f3c9beae | |||
d0b178cef1 | |||
b3230e1ac0 | |||
03df0fb5d2 | |||
9471879bd4 | |||
ab5b6459df | |||
8ce5d3198d | |||
09c2cbc04a | |||
4c347044c9 | |||
19e7ab7315 | |||
6de3d431d9 | |||
b14773bd64 | |||
26a7a33b88 | |||
5aa5811a16 | |||
c2fa2d4dc9 | |||
32335c8b34 | |||
04c2b26972 | |||
ee10d7e6ff | |||
bb79c4da2f |
@ -48,7 +48,7 @@ steps:
|
||||
agents:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||
@ -76,7 +76,7 @@ steps:
|
||||
queue: arm64_cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
|
||||
|
||||
# Add job to create multi-arch manifest
|
||||
|
@ -584,8 +584,9 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV32ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
|
@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
|
||||
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
|
||||
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
|
||||
# Only build FlashMLA kernels if we are building for something compatible with
|
||||
# sm90a
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
|
||||
set(SUPPORT_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
|
||||
list(APPEND SUPPORT_ARCHS 9.0a)
|
||||
endif()
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
|
||||
list(APPEND SUPPORT_ARCHS 10.0a)
|
||||
endif()
|
||||
|
||||
|
||||
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
|
||||
if(FLASH_MLA_ARCHS)
|
||||
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
|
||||
|
||||
set(FlashMLA_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
|
||||
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_SOURCES
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc)
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set(FlashMLA_Extension_INCLUDES
|
||||
${flashmla_SOURCE_DIR}/csrc
|
||||
${flashmla_SOURCE_DIR}/csrc/sm90
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/include
|
||||
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
|
||||
)
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${FlashMLA_Extension_SOURCES}"
|
||||
CUDA_ARCHS "${FLASH_MLA_ARCHS}")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_C
|
||||
DESTINATION vllm
|
||||
@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
|
||||
define_gpu_extension_target(
|
||||
_flashmla_extension_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${FlashMLA_Extension_SOURCES}
|
||||
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
|
||||
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
|
||||
target_compile_options(_flashmla_extension_C PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
|
||||
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
|
||||
else()
|
||||
# Create an empty target for setup.py when not targeting sm90a systems
|
||||
# Create empty targets for setup.py when not targeting sm90a systems
|
||||
add_custom_target(_flashmla_C)
|
||||
add_custom_target(_flashmla_extension_C)
|
||||
endif()
|
||||
|
||||
|
@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_page_table(
|
||||
blk_coord,
|
||||
problem_shape,
|
||||
params.mainloop,
|
||||
shared_storage.tensors,
|
||||
pipeline_page_table, pipeline_pt_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_cpasync(
|
||||
blk_coord,
|
||||
@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
params.mainloop_params,
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv,
|
||||
local_split_kv,
|
||||
/* must be shared pipe */
|
||||
pipeline_page_table, pipeline_pt_consumer_state
|
||||
);
|
||||
@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_tma</* paged= */ true>(
|
||||
blk_coord,
|
||||
@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||
}
|
||||
@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
CUTLASS_PRAGMA_NO_UNROLL
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
load_tma<false>(
|
||||
blk_coord,
|
||||
@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
shared_storage.tensors,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||
}
|
||||
@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto local_split_kv = params.split_kv;
|
||||
auto local_split_kv = params.split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
mma(blk_coord,
|
||||
problem_shape,
|
||||
@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_mma_s, pipeline_mma_s_producer_state,
|
||||
pipeline_p_mma, pipeline_p_mma_consumer_state,
|
||||
pipeline_mma_o, pipeline_mma_o_producer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||
auto blk_coord = tile_scheduler.get_block_coord();
|
||||
auto problem_shape = params.problem_shape;
|
||||
auto split_kv = params.split_kv;
|
||||
auto local_split_kv = split_kv;
|
||||
auto split_kv = params.split_kv;
|
||||
auto local_split_kv = split_kv;
|
||||
if (params.mainloop.ptr_seq != nullptr) {
|
||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
if (params.ptr_split_kv != nullptr) {
|
||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||
}
|
||||
}
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
if (local_split_kv <= get<3>(blk_coord))
|
||||
continue;
|
||||
compute(
|
||||
blk_coord,
|
||||
@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
pipeline_mma_s, pipeline_mma_s_consumer_state,
|
||||
pipeline_p_mma, pipeline_p_mma_producer_state,
|
||||
pipeline_mma_o, pipeline_mma_o_consumer_state,
|
||||
local_split_kv
|
||||
local_split_kv
|
||||
);
|
||||
}
|
||||
|
||||
@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
||||
cutlass::arch::NamedBarrier(
|
||||
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
|
||||
kNamedBarrierEpilogue
|
||||
).arrive();
|
||||
).arrive_and_wait();
|
||||
|
||||
return;
|
||||
}
|
||||
|
@ -56,3 +56,11 @@ void cp_gather_cache(
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& cu_seq_lens, // [BATCH+1]
|
||||
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
|
||||
|
||||
// Indexer K quantization and cache function
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt);
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cfloat> // FLT_MIN
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
@ -396,6 +397,180 @@ __global__ void concat_and_cache_mla_kernel(
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_ds_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int entry_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
const int64_t dst_idx_start =
|
||||
block_idx * block_stride + block_offset * entry_stride;
|
||||
|
||||
// Create 4 tile scales in shared memory
|
||||
__shared__ float smem[20];
|
||||
float* shard_abs_max = smem;
|
||||
float* tile_scales = smem + 16;
|
||||
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
|
||||
// Cast kv_cache to 16_bit for RoPE values
|
||||
scalar_t* kv_cache_16bit =
|
||||
reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
|
||||
|
||||
// The last 64 threads handle the RoPE part
|
||||
if (threadIdx.x >= kv_lora_rank) {
|
||||
const int8_t pe_idx = threadIdx.x - kv_lora_rank;
|
||||
const int64_t src_idx = token_idx * k_pe_stride + pe_idx;
|
||||
// RoPE values start after the packed 8-bit NoPE values and the
|
||||
// 32-bit scales
|
||||
const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx;
|
||||
kv_cache_16bit[dst_idx] = k_pe[src_idx];
|
||||
return;
|
||||
}
|
||||
|
||||
// Determine the scale for each chunk of NoPE
|
||||
const int16_t tile_idx = threadIdx.x >> 7;
|
||||
const int16_t warp_idx = (threadIdx.x & 127) >> 5;
|
||||
const int16_t lane_idx = threadIdx.x & 31;
|
||||
|
||||
// Load the NoPE element for this thread into registers
|
||||
const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x;
|
||||
const scalar_t src_val = kv_c[src_idx];
|
||||
|
||||
// Warp-level reduction to find the max absolute value in the warp
|
||||
float max_abs = fabsf(src_val);
|
||||
#pragma unroll
|
||||
for (int offset = 16; offset > 0; offset /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(UINT64_MAX, max_abs, offset));
|
||||
#else
|
||||
max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
// The first lane of each warp in each tile writes the max_abs of this part
|
||||
// of the tile to shared memory
|
||||
if (lane_idx == 0) {
|
||||
shard_abs_max[tile_idx * 4 + warp_idx] = max_abs;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// The first lane of the first warp in each tile computes the scale for the
|
||||
// tile and writes it to shared memory and to kv_cache
|
||||
if (warp_idx == 0 && lane_idx == 0) {
|
||||
float4 shard_abs_max_vec =
|
||||
reinterpret_cast<float4*>(shard_abs_max)[tile_idx];
|
||||
float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y),
|
||||
fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) /
|
||||
448.f;
|
||||
|
||||
// Avoid division by zero in `scaled_convert`
|
||||
tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN);
|
||||
float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
|
||||
const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
|
||||
kv_cache_32bit[dst_idx] = tile_scales[tile_idx];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Now all threads in the block scale and write their element
|
||||
const float scale_val = tile_scales[tile_idx];
|
||||
const int64_t dst_idx = dst_idx_start + threadIdx.x;
|
||||
kv_cache[dst_idx] =
|
||||
fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
|
||||
src_val, scale_val);
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void indexer_k_quant_and_cache_kernel(
|
||||
const scalar_t* __restrict__ k, // [num_tokens, head_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int head_dim, // dimension of each head
|
||||
const int quant_block_size, // quantization block size
|
||||
const int cache_block_size, // cache block size
|
||||
const int cache_stride, // stride for each token in kv_cache
|
||||
const bool use_ue8m0 // use ue8m0 scale format
|
||||
) {
|
||||
constexpr int VEC_SIZE = 4;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
|
||||
threadIdx.y * blockDim.x + threadIdx.x) *
|
||||
VEC_SIZE;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
const int64_t block_idx = slot_idx / cache_block_size;
|
||||
const int64_t block_offset = slot_idx % cache_block_size;
|
||||
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
|
||||
return;
|
||||
}
|
||||
|
||||
float2 k_val = (reinterpret_cast<const float2*>(
|
||||
k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
|
||||
scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
|
||||
float amax = 0.0f;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
|
||||
// Reduced amax
|
||||
for (int mask = 16; mask > 0; mask /= 2) {
|
||||
#ifdef USE_ROCM
|
||||
amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
|
||||
#else
|
||||
amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
|
||||
#endif
|
||||
}
|
||||
#ifndef USE_ROCM
|
||||
__syncwarp();
|
||||
#endif
|
||||
float scale = fmaxf(amax, 1e-4) / 448.0f;
|
||||
if (use_ue8m0) {
|
||||
scale = exp2f(ceilf(log2f(scale)));
|
||||
}
|
||||
|
||||
const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
|
||||
block_offset * head_dim + head_dim_idx;
|
||||
for (int i = 0; i < VEC_SIZE; i++) {
|
||||
kv_cache[dst_offset + i] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
const int64_t dst_scale_idx =
|
||||
block_idx * cache_block_size * cache_stride +
|
||||
cache_block_size * head_dim +
|
||||
(block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
|
||||
reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -438,7 +613,7 @@ void reshape_and_cache(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
CALL_RESHAPE_AND_CACHE);
|
||||
}
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
@ -509,6 +684,18 @@ void reshape_and_cache_flash(
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
// KV_T is the data type of key and value tensors.
|
||||
// CACHE_T is the stored data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
|
||||
kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
@ -531,20 +718,44 @@ void concat_and_cache_mla(
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
|
||||
TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
|
||||
"kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
|
||||
TORCH_CHECK(kv_c.itemsize() == 2,
|
||||
"kv_c.itemsize() must be 2 for fp8_ds_mla");
|
||||
TORCH_CHECK(k_pe.itemsize() == 2,
|
||||
"k_pe.itemsize() must be 2 for fp8_ds_mla");
|
||||
} else {
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
}
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
if (kv_cache_dtype == "fp8_ds_mla") {
|
||||
dim3 grid(num_tokens);
|
||||
// For the NoPE part, each tile of 128 elements is handled by 4 warps
|
||||
// (128 threads). There are 4 total tiles, so 16 warps (512 threads).
|
||||
// The first thread of the first warp in each tile writes the scale
|
||||
// value for the tile. The RoPE part (last 64 elements) is handled
|
||||
// by another 2 warps (64 threads).
|
||||
// So in total, we use 18 warps (576 threads) per block.
|
||||
dim3 block(576);
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_DS_MLA);
|
||||
} else {
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -922,3 +1133,42 @@ void cp_gather_cache(
|
||||
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
|
||||
}
|
||||
}
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(k.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size, \
|
||||
cache_block_size, cache_stride, use_ue8m0);
|
||||
|
||||
void indexer_k_quant_and_cache(
|
||||
torch::Tensor& k, // [num_tokens, head_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
int64_t quant_block_size, // quantization block size
|
||||
const std::string& scale_fmt) {
|
||||
int num_tokens = k.size(0);
|
||||
int head_dim = k.size(1);
|
||||
int cache_block_size = kv_cache.size(1);
|
||||
int cache_stride = kv_cache.size(2);
|
||||
bool use_ue8m0 = scale_fmt == "ue8m0";
|
||||
|
||||
TORCH_CHECK(k.device() == kv_cache.device(),
|
||||
"k and kv_cache must be on the same device");
|
||||
TORCH_CHECK(k.device() == slot_mapping.device(),
|
||||
"k and slot_mapping must be on the same device");
|
||||
TORCH_CHECK(head_dim % quant_block_size == 0,
|
||||
"head_dim must be divisible by quant_block_size");
|
||||
|
||||
constexpr int vec_size = 4;
|
||||
dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
|
||||
(quant_block_size * vec_size));
|
||||
dim3 block(32, vec_size);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
|
||||
CALL_INDEXER_K_QUANT_AND_CACHE);
|
||||
}
|
||||
|
@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else if (KV_DTYPE == "fp8_ds_mla") { \
|
||||
if (SRC_DTYPE == at::ScalarType::Float) { \
|
||||
FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::Half) { \
|
||||
FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
|
||||
FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, \
|
||||
"Unsupported input type of kv cache: ", SRC_DTYPE); \
|
||||
} \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
|
||||
} \
|
||||
|
@ -713,6 +713,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
"int quant_block_size, str kv_cache_dtype) -> ()");
|
||||
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
|
||||
&indexer_k_quant_and_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
|
||||
|
@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12
|
||||
#
|
||||
# Example:
|
||||
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
|
||||
# Important: We build with an old version of Ubuntu to maintain broad
|
||||
# compatibility with other Linux OSes. The main reason for this is that the
|
||||
# glibc version is baked into the distro, and binaries built with one glibc
|
||||
# version are not backwards compatible with OSes that use an earlier version.
|
||||
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
@ -75,34 +80,19 @@ ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install Python and other dependencies
|
||||
# Install system dependencies and uv, then create Python virtual environment
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo python3-pip \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
|
||||
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
|
||||
&& ln -s /opt/venv/bin/python3 /usr/bin/python3 \
|
||||
&& ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
|
||||
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
# Activate virtual environment and add uv to PATH
|
||||
ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
|
||||
ENV VIRTUAL_ENV="/opt/venv"
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
@ -142,7 +132,7 @@ WORKDIR /workspace
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/cuda.txt \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
COPY . .
|
||||
@ -269,7 +259,7 @@ COPY requirements/lint.txt requirements/lint.txt
|
||||
COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt \
|
||||
uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
@ -404,6 +394,9 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
export FLASHINFER_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# HACK: We need these to run flashinfer.aot before installing flashinfer, get from the package in the future
|
||||
uv pip install --system cuda-python==$(echo $CUDA_VERSION | cut -d. -f1,2) pynvml==$(echo $CUDA_VERSION | cut -d. -f1) nvidia-nvshmem-cu$(echo $CUDA_VERSION | cut -d. -f1)
|
||||
# Build AOT kernels
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
|
@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
|
||||
#
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
|
||||
ARG CUDA_VERSION=12.8.0
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG TARGETPLATFORM
|
||||
|
@ -6,6 +6,13 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
|
||||
We are actively iterating on multi-modal support. See [this RFC](gh-issue:4194) for upcoming changes,
|
||||
and [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) if you have any feedback or feature requests.
|
||||
|
||||
!!! tip
|
||||
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
||||
|
||||
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.
|
||||
|
||||
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
||||
|
||||
## Offline Inference
|
||||
|
||||
To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
|
||||
|
@ -60,6 +60,15 @@ Key points from the PyTorch security guide:
|
||||
- Implement proper authentication and authorization for management interfaces
|
||||
- Follow the principle of least privilege for all system components
|
||||
|
||||
### 4. **Restrict Domains Access for Media URLs:**
|
||||
|
||||
Restrict domains that vLLM can access for media URLs by setting
|
||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||
|
||||
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
|
||||
redirects from being followed to bypass domain restrictions.
|
||||
|
||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||
|
||||
While vLLM is designed to allow unsafe network services to be isolated to
|
||||
|
@ -54,6 +54,7 @@ def parse_args():
|
||||
"--method",
|
||||
type=str,
|
||||
default="eagle",
|
||||
choices=["ngram", "eagle", "eagle3", "mtp"],
|
||||
)
|
||||
parser.add_argument("--num-spec-tokens", type=int, default=2)
|
||||
parser.add_argument("--prompt-lookup-max", type=int, default=5)
|
||||
@ -118,9 +119,9 @@ def main(args):
|
||||
"prompt_lookup_max": args.prompt_lookup_max,
|
||||
"prompt_lookup_min": args.prompt_lookup_min,
|
||||
}
|
||||
elif args.method.endswith("mtp"):
|
||||
elif args.method == "mtp":
|
||||
speculative_config = {
|
||||
"method": args.method,
|
||||
"method": "mtp",
|
||||
"num_speculative_tokens": args.num_spec_tokens,
|
||||
}
|
||||
else:
|
||||
|
4
setup.py
4
setup.py
@ -322,6 +322,8 @@ class precompiled_wheel_utils:
|
||||
"vllm/_C.abi3.so",
|
||||
"vllm/_moe_C.abi3.so",
|
||||
"vllm/_flashmla_C.abi3.so",
|
||||
"vllm/_flashmla_extension_C.abi3.so",
|
||||
"vllm/_sparse_flashmla_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
|
||||
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
|
||||
"vllm/cumem_allocator.abi3.so",
|
||||
@ -589,6 +591,8 @@ if _is_cuda():
|
||||
# not targeting a hopper system
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_C", optional=True))
|
||||
ext_modules.append(
|
||||
CMakeExtension(name="vllm._flashmla_extension_C", optional=True))
|
||||
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
|
||||
|
||||
if _build_custom_ops():
|
||||
|
@ -191,7 +191,6 @@ class AttentionQuantPatternModel(torch.nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
),
|
||||
layer_names=[self.attn.layer_name],
|
||||
vllm_config=self.vllm_config,
|
||||
|
@ -45,6 +45,7 @@ class MockModelConfig:
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
allowed_local_media_path: str = ""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
skip_tokenizer_init: bool = False
|
||||
|
@ -240,6 +240,7 @@ class MockModelConfig:
|
||||
logits_processor_pattern = None
|
||||
diff_sampling_param: Optional[dict] = None
|
||||
allowed_local_media_path: str = ""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
encoder_config = None
|
||||
generation_config: str = "auto"
|
||||
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
|
||||
parse_chat_messages,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format,
|
||||
resolve_chat_template_kwargs,
|
||||
resolve_hf_chat_template)
|
||||
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
||||
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
|
||||
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
|
||||
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
|
||||
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
|
||||
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
||||
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
|
||||
assert isinstance(chat_template, str)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, expected_kwargs",
|
||||
[
|
||||
(
|
||||
QWEN2VL_MODEL_ID,
|
||||
{
|
||||
"add_vision_id", "add_generation_prompt",
|
||||
"continue_final_message", "tools"
|
||||
},
|
||||
),
|
||||
(
|
||||
QWEN3_MODEL_ID,
|
||||
{
|
||||
"enable_thinking", "add_generation_prompt",
|
||||
"continue_final_message", "tools"
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
|
||||
expected_kwargs):
|
||||
"""checks that chat_template is a dict type for HF models."""
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
tools = ([{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "dummy_function_name",
|
||||
"description": "This is a dummy function",
|
||||
"parameters": sample_json_schema,
|
||||
},
|
||||
}])
|
||||
|
||||
chat_template_kwargs = {
|
||||
# both unused
|
||||
"unsed_kwargs_1": 123,
|
||||
"unsed_kwargs_2": "abc",
|
||||
# should not appear
|
||||
"chat_template": "{% Hello world! %}",
|
||||
# used by tokenizer
|
||||
"continue_final_message": True,
|
||||
"tools": tools,
|
||||
# both used by Qwen2-VL and Qwen3
|
||||
"add_generation_prompt": True,
|
||||
# only used by Qwen2-VL
|
||||
"add_vision_id": True,
|
||||
# only used by Qwen3
|
||||
"enable_thinking": True,
|
||||
}
|
||||
|
||||
model_config = ModelConfig(
|
||||
model,
|
||||
tokenizer=model_info.tokenizer or model,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
|
||||
# Build the tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
model,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
|
||||
# Test detecting the tokenizer's chat_template
|
||||
chat_template = resolve_hf_chat_template(
|
||||
tokenizer,
|
||||
chat_template=None,
|
||||
tools=tools,
|
||||
model_config=model_config,
|
||||
)
|
||||
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
)
|
||||
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
|
||||
|
||||
|
||||
# NOTE: Qwen2-Audio default chat template is specially defined inside
|
||||
# processor class instead of using `tokenizer_config.json`
|
||||
# yapf: disable
|
||||
|
@ -593,6 +593,119 @@ def test_concat_and_cache_mla(
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_concat_and_cache_ds_mla(
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if dtype.itemsize != 2:
|
||||
pytest.skip("ds_mla only supports 16-bit input")
|
||||
kv_cache_dtype = "fp8_ds_mla"
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
|
||||
|
||||
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device)
|
||||
|
||||
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
|
||||
tile_data = torch.zeros(128, dtype=dtype, device=device)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
|
||||
ref_cache_slice = ref_cache[block_idx, block_offset]
|
||||
ref_cache_16bit = ref_cache_slice.view(dtype)
|
||||
ref_cache_32bit = ref_cache_slice.view(torch.float32)
|
||||
|
||||
kv_c_data = kv_c[i]
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = (tile_idx + 1) * 128
|
||||
tile_data[:] = kv_c_data[tile_start:tile_end]
|
||||
|
||||
# tile_scale = tile_data.amax().to(torch.float32) / 448.
|
||||
# NOTE: Using torch's amax() gives different results,
|
||||
# so this must be manually computed.
|
||||
tile_data_float = tile_data.to(torch.float32)
|
||||
manual_max = abs(tile_data_float[0])
|
||||
for j in range(1, 128):
|
||||
manual_max = max(manual_max, abs(tile_data_float[j]))
|
||||
tile_scale = manual_max / 448.
|
||||
|
||||
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
|
||||
|
||||
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
|
||||
tile_data,
|
||||
tile_scale.item(),
|
||||
kv_dtype="fp8")
|
||||
|
||||
for j in range(qk_rope_head_dim):
|
||||
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla,
|
||||
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
kv_cache_slice = kv_cache[block_idx, block_offset]
|
||||
ref_cache_slice = ref_cache[block_idx, block_offset]
|
||||
|
||||
kv_nope = kv_cache_slice[:kv_lora_rank]
|
||||
ref_nope = ref_cache_slice[:kv_lora_rank]
|
||||
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
ref_scales = ref_cache_slice.view(
|
||||
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
|
||||
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
|
||||
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
|
||||
|
279
tests/kernels/attention/test_deepgemm_attention.py
Normal file
279
tests/kernels/attention/test_deepgemm_attention.py
Normal file
@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits, get_num_sms,
|
||||
get_paged_mqa_logits_metadata)
|
||||
|
||||
|
||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (num_blocks, block_size, 1, head_dim)
|
||||
num_blocks, block_size, num_heads, head_dim = x.shape
|
||||
assert num_heads == 1
|
||||
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
x_fp8 = torch.empty(
|
||||
(num_blocks, block_size * (head_dim + 4)),
|
||||
device=x.device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
x_fp8[:, :block_size * head_dim] = x_scaled.view(
|
||||
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
|
||||
x_fp8[:,
|
||||
block_size * head_dim:] = sf.view(num_blocks,
|
||||
block_size).view(dtype=torch.uint8)
|
||||
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
||||
|
||||
|
||||
def per_custom_dims_cast_to_fp8(
|
||||
x: torch.Tensor, dims: tuple,
|
||||
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
||||
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
|
||||
|
||||
def _generate_cp_test_data(seq_len: int, seq_len_kv: int):
|
||||
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
|
||||
chunk_size = seq_len // 2
|
||||
cp_size = seq_len_kv // seq_len
|
||||
cp_id = cp_size // 3
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
for i in range(chunk_size):
|
||||
ke[i] = cp_id * chunk_size + i
|
||||
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
|
||||
return ks, ke
|
||||
|
||||
|
||||
def _ref_fp8_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
cu_seqlen_ks: torch.Tensor,
|
||||
cu_seqlen_ke: torch.Tensor,
|
||||
):
|
||||
seq_len_kv = kv.shape[0]
|
||||
|
||||
k = kv
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
|
||||
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
>= cu_seqlen_ks[:, None])
|
||||
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
< cu_seqlen_ke[:, None])
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,and->hmn", q, k)
|
||||
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
|
||||
logits = logits.masked_fill(~mask, float("-inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="SM90 and SM100 only")
|
||||
def test_deepgemm_fp8_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
num_heads, head_dim = 32, 128
|
||||
for seq_len in (512, ):
|
||||
for seq_len_kv in (1024, ):
|
||||
for disable_cp in (False, True):
|
||||
q = torch.randn(
|
||||
seq_len,
|
||||
num_heads,
|
||||
head_dim,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv = torch.randn(seq_len_kv,
|
||||
head_dim,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16)
|
||||
weights = torch.randn(seq_len,
|
||||
num_heads,
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
if disable_cp:
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.arange(seq_len, dtype=torch.int,
|
||||
device="cuda") + (seq_len_kv - seq_len)
|
||||
else:
|
||||
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
|
||||
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
|
||||
|
||||
ref_logits = _ref_fp8_mqa_logits(
|
||||
q=q,
|
||||
kv=kv,
|
||||
weights=weights,
|
||||
cu_seqlen_ks=ks,
|
||||
cu_seqlen_ke=ke,
|
||||
)
|
||||
|
||||
ref_neginf_mask = ref_logits == float("-inf")
|
||||
neginf_mask = logits == float("-inf")
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
|
||||
def _ref_fp8_paged_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
batch_size, next_n, _, _ = q.size()
|
||||
_, block_size, _, _ = kv_cache.size()
|
||||
logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device=q.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
context_lens_list = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens_list[i]
|
||||
q_offsets = torch.arange(context_len - next_n,
|
||||
context_len,
|
||||
device="cuda")
|
||||
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
|
||||
0, 1).contiguous())
|
||||
for block_rk in range(cdiv(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
k_offsets = torch.arange(
|
||||
block_rk * block_size,
|
||||
(block_rk + 1) * block_size,
|
||||
device="cuda",
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
|
||||
<= q_offsets[:, None])
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n:(i + 1) * next_n,
|
||||
block_rk * block_size:(block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
|
||||
float("-inf"))
|
||||
return logits
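# Illustrative sketch, not part of the diff: the block-table indexing used by
# the reference above maps a (batch, token position) pair to a physical cache
# slot. The helper name below is hypothetical and only makes that mapping
# explicit.
def _token_slot(block_tables, batch_idx: int, pos: int, block_size: int) -> int:
    block_rk = pos // block_size                         # logical block index
    block_idx = int(block_tables[batch_idx][block_rk])   # physical block id
    return block_idx * block_size + pos % block_size     # flat slot in the cache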
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="SM90 and SM100 only")
|
||||
def test_deepgemm_fp8_paged_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
max_model_len = 4096
|
||||
for batch_size, next_n in [(4, 1), (2, 2)]:
|
||||
for heads, index_dim in [(32, 128)]:
|
||||
for avg_kv in (2048, ):
|
||||
num_blocks, blocksize = max_model_len * 2, 64
|
||||
|
||||
q = torch.randn(
|
||||
(batch_size, next_n, heads, index_dim),
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv_cache = torch.randn(
|
||||
(num_blocks, blocksize, 1, index_dim),
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
weights = torch.randn(
|
||||
(batch_size * next_n, heads),
|
||||
device="cuda",
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
context_lens = (torch.randint(int(0.8 * avg_kv),
|
||||
int(1.2 * avg_kv),
|
||||
(batch_size, )).cuda().to(
|
||||
torch.int32))
|
||||
max_block_len = ((context_lens.max().item() + blocksize - 1) //
|
||||
blocksize * blocksize)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, max_block_len),
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
counter = 0
|
||||
block_idx_pool = list(range(num_blocks))
|
||||
random.shuffle(block_idx_pool)
|
||||
for i in range(batch_size):
|
||||
ctx_len = int(context_lens[i].item())
|
||||
for j in range((ctx_len + blocksize - 1) // blocksize):
|
||||
block_tables[i][j] = block_idx_pool[counter]
|
||||
counter += 1
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
schedule_metadata = get_paged_mqa_logits_metadata(
|
||||
context_lens, blocksize, get_num_sms())
|
||||
logits = fp8_paged_mqa_logits(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_paged_mqa_logits(
|
||||
q,
|
||||
kv_cache,
|
||||
weights,
|
||||
context_lens,
|
||||
block_tables,
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
positions = (torch.arange(max_model_len,
|
||||
device="cuda").unsqueeze(0).expand(
|
||||
batch_size * next_n, -1))
|
||||
row_indices = (
|
||||
torch.arange(batch_size * next_n, device="cuda") // next_n)
|
||||
next_n_offset = (
|
||||
torch.arange(batch_size * next_n, device="cuda") % next_n)
|
||||
mask = positions <= (context_lens[row_indices] - next_n +
|
||||
next_n_offset).unsqueeze(1)
|
||||
|
||||
logits = logits.masked_fill(~mask, 0)
|
||||
ref_logits = ref_logits.masked_fill(~mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
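# Note (illustrative, not part of the diff): the mask built above encodes the
# causal limit for multi-token (next_n) decoding. Flattened row r corresponds
# to batch r // next_n and draft offset r % next_n, whose absolute query
# position is context_len - next_n + offset, so only KV positions up to that
# index are kept; this mirrors the k_offsets <= q_offsets check in the
# reference implementation.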
|
@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
descale_k = None
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k,
|
||||
)
|
||||
return flash_mla_with_kvcache(q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
|
119
tests/kernels/attention/test_flashmla_sparse.py
Normal file
@ -0,0 +1,119 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
def _cuda_sm90_available() -> bool:
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
return major == 9
|
||||
|
||||
|
||||
def test_sparse_flashmla_metadata_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 128
|
||||
num_heads_k = 1
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
topk = 128
|
||||
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
assert tile_md.dtype == torch.int32
|
||||
assert num_splits.dtype == torch.int32
|
||||
|
||||
|
||||
def test_sparse_flashmla_decode_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 1
|
||||
head_dim_k = 576
|
||||
head_dim_v = 512
|
||||
num_heads_k = 1
|
||||
page_block_size = 64
|
||||
bytes_per_token = 656
|
||||
topk = 128
|
||||
|
||||
# Metadata
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
# q_heads_per_hk = num_heads_q // num_heads_k
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
|
||||
# Inputs
|
||||
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
|
||||
dtype=torch.bfloat16,
|
||||
device=device)
|
||||
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
indices = torch.zeros((batch_size, seqlen_q, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
block_table = torch.zeros((batch_size, 128),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
out, lse = fm.flash_mla_with_kvcache(q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_md,
|
||||
num_splits,
|
||||
indices=indices,
|
||||
is_fp8_kvcache=True)
|
||||
assert out.shape[0] == batch_size
|
||||
assert out.shape[-1] == head_dim_v
|
||||
assert lse.shape[0] == batch_size
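# Note (illustrative, not part of the diff): bytes_per_token = 656 above is
# consistent with the fp8_ds_mla cache entry size used elsewhere in these
# tests: kv_lora_rank + 4 * 4 + 2 * rope_dim = 512 + 16 + 2 * 64 = 656.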
|
||||
|
||||
|
||||
def test_sparse_flashmla_prefill_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
|
||||
device = torch.device("cuda")
|
||||
s_q = 1
|
||||
s_kv = 1
|
||||
h_q = 64 # kernel expects multiple of 64
|
||||
h_kv = 1
|
||||
d_qk = 576
|
||||
d_v = 512
|
||||
topk = 128
|
||||
|
||||
q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device)
|
||||
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
|
||||
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
|
||||
|
||||
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
|
||||
d_v)
|
||||
assert out.shape == (s_q, h_q, d_v)
|
||||
assert max_logits.shape == (s_q, h_q)
|
||||
assert lse.shape == (s_q, h_q)
|
245
tests/kernels/attention/test_pack_unpack_triton.py
Normal file
@ -0,0 +1,245 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch.testing import assert_close
|
||||
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
|
||||
|
||||
def test_pack_seq_basic_fp8():
|
||||
"""Test basic functionality of pack_seq_triton with fp8 and 3D tensors."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors (N, H, D)
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4)
|
||||
(10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8)
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32)
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check output shape and properties
|
||||
expected_shape = (B, max(lengths_list), H, D)
|
||||
assert packed.shape == expected_shape
|
||||
assert packed.dtype == dtype
|
||||
assert packed.device == x.device
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = sum(lengths_list[:b])
|
||||
seq_len = lengths_list[b]
|
||||
|
||||
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
|
||||
actual_data = packed[b, :seq_len].to(torch.float32)
|
||||
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
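# Illustrative sketch, not part of the diff: a plain-PyTorch reference for the
# packing behaviour exercised above. pack_seq_ref is a hypothetical name; the
# kernel under test is pack_seq_triton, and this sketch assumes an ordinary
# float dtype rather than fp8.
def pack_seq_ref(x: torch.Tensor, lengths: torch.Tensor,
                 pad_value: float = float("-inf")) -> torch.Tensor:
    # x: (N, H, D) with all sequences concatenated along dim 0.
    # Returns (B, max_len, H, D), padded with pad_value past each length.
    B, max_len = lengths.numel(), int(lengths.max().item())
    out = torch.full((B, max_len, *x.shape[1:]), pad_value,
                     dtype=x.dtype, device=x.device)
    start = 0
    for b in range(B):
        n = int(lengths[b].item())
        out[b, :n] = x[start:start + n]
        start += n
    return out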
|
||||
|
||||
|
||||
def test_pack_seq_custom_padding_fp8():
|
||||
"""Test pack_seq_triton with custom padding values for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test with different padding values
|
||||
for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]:
|
||||
result = pack_seq_triton(x, lengths, pad_value=pad_value)
|
||||
|
||||
# Check valid data
|
||||
for b in range(B):
|
||||
start_idx = b * 10
|
||||
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
|
||||
actual_data = result[b, :10].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
# Check padding (fp8 has limited range, so check for large values)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
if pad_value < 0:
|
||||
assert torch.all(padded_data < -50) # Large negative values
|
||||
elif pad_value > 0:
|
||||
assert torch.all(padded_data > 50) # Large positive values
|
||||
else:
|
||||
assert torch.allclose(padded_data,
|
||||
torch.zeros_like(padded_data),
|
||||
atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_default_negative_inf_padding_fp8():
|
||||
"""Test that pack_seq_triton uses -inf padding by default for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
# B = 2
|
||||
N, H, D = 20, 8, 16
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check that padding is large negative values (fp8 representation of -inf)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
assert torch.all(
|
||||
padded_data < -100) # fp8 -inf is represented as large negative number
|
||||
|
||||
|
||||
def test_pack_seq_edge_cases_fp8():
|
||||
"""Test pack_seq_triton with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (1, 10, 8, 16)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 1, 4, 8)
|
||||
|
||||
# Test with different sequence lengths
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
result = pack_seq_triton(x, lengths)
|
||||
assert result.shape == (3, 7, 8, 16)
|
||||
|
||||
|
||||
def test_pack_seq_different_block_sizes_fp8():
|
||||
"""Test pack_seq_triton with different block sizes for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 100, 16, 32, 4
|
||||
lengths = torch.tensor([25, 25, 25, 25], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
# Test different block sizes
|
||||
for block_t, block_d in [(32, 32), (64, 64), (128, 128)]:
|
||||
result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d)
|
||||
|
||||
assert result.shape == (B, 25, H, D)
|
||||
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = b * 25
|
||||
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
|
||||
actual_data = result[b, :25].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_shape_consistency():
|
||||
"""Test that pack_seq_triton maintains shape consistency."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
N, H, D, B = 20, 8, 16, 2
|
||||
lengths = torch.tensor([10, 10], device=device)
|
||||
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
|
||||
result = pack_seq_triton(x, lengths)
|
||||
|
||||
# Check shape consistency
|
||||
assert result.shape[0] == B # Batch dimension
|
||||
assert result.shape[1] == lengths.max().item() # Max sequence length
|
||||
assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved
|
||||
|
||||
|
||||
def test_pack_unpack_roundtrip_fp8():
|
||||
"""Test that pack -> unpack gives us back the original data for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test cases with 3D tensors
|
||||
test_cases = [
|
||||
(6, 8, 4, 2, [3, 3]),
|
||||
(10, 4, 8, 3, [2, 4, 4]),
|
||||
(20, 16, 32, 4, [5, 5, 5, 5]),
|
||||
(15, 8, 16, 3, [7, 5, 3]),
|
||||
]
|
||||
|
||||
for N, H, D, B, lengths_list in test_cases:
|
||||
# Create input tensor with small values for fp8
|
||||
x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor(lengths_list, device=device)
|
||||
|
||||
# Pack the data
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
|
||||
# Unpack the data
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
|
||||
# Check that we get back the original data (within fp8 precision)
|
||||
assert unpacked.shape == x.shape
|
||||
x_f32 = x.to(torch.float32)
|
||||
unpacked_f32 = unpacked.to(torch.float32)
|
||||
assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Unpack without explicit start locations (computed in kernel)
|
||||
unpacked_with_loc = unpack_seq_triton(packed, lengths)
|
||||
assert_close(x_f32,
|
||||
unpacked_with_loc.to(torch.float32),
|
||||
rtol=1e-3,
|
||||
atol=1e-2)
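# Illustrative sketch, not part of the diff: the inverse of the pack_seq_ref
# sketch earlier in this file, mirroring what unpack_seq_triton does for
# ordinary float dtypes. unpack_seq_ref is a hypothetical name.
def unpack_seq_ref(packed: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Concatenate the valid prefix of each batch row back into (N, H, D).
    return torch.cat(
        [packed[b, :int(n)] for b, n in enumerate(lengths)], dim=0)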
|
||||
|
||||
|
||||
def test_unpack_seq_triton_edge_cases_fp8():
|
||||
"""Test unpack function with edge cases for fp8."""
|
||||
device = "cuda"
|
||||
dtype = torch.float8_e4m3fn
|
||||
|
||||
# Test with single batch element
|
||||
x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([10], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([1, 1, 1], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
# Only compare the first 3 elements that were actually packed
|
||||
assert_close(x[:3].to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
lengths = torch.tensor([5, 7, 3], device=device)
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
@ -207,6 +207,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True),
|
||||
"DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501
|
||||
trust_remote_code=True),
|
||||
"DeepseekV32ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3.2-Exp"),
|
||||
"Ernie4_5ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT",
|
||||
min_transformers_version="4.54"),
|
||||
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT",
|
||||
|
@ -8,7 +8,8 @@ import pytest
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.utils import GiB_bytes
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.core.kv_cache_utils import (generate_scheduler_kv_cache_config,
|
||||
get_kv_cache_configs)
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
|
||||
from ..utils import create_new_process_for_each_test
|
||||
@ -62,11 +63,13 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
|
||||
# Avoid calling model.forward()
|
||||
def _initialize_kv_caches_v1(self, vllm_config):
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
scheduler_kv_cache_config = get_kv_cache_configs(
|
||||
kv_cache_configs = get_kv_cache_configs(
|
||||
vllm_config,
|
||||
kv_cache_specs,
|
||||
[10 * GiB_bytes],
|
||||
)[0]
|
||||
)
|
||||
scheduler_kv_cache_config = generate_scheduler_kv_cache_config(
|
||||
kv_cache_configs)
|
||||
|
||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
||||
return 1, 0, scheduler_kv_cache_config
|
||||
|
@ -66,7 +66,12 @@ async def test_fetch_image_http(image_url: str):
|
||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
|
||||
raw_image_url: str, suffix: str):
|
||||
connector = MediaConnector()
|
||||
connector = MediaConnector(
|
||||
# Domain restriction should not apply to data URLs.
|
||||
allowed_media_domains=[
|
||||
"www.bogotobogo.com",
|
||||
"github.com",
|
||||
])
|
||||
url_image = url_images[raw_image_url]
|
||||
|
||||
try:
|
||||
@ -387,3 +392,29 @@ def test_argsort_mm_positions(case):
|
||||
modality_idxs = argsort_mm_positions(mm_positions)
|
||||
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
async def test_allowed_media_domains(video_url: str, num_frames: int):
|
||||
connector = MediaConnector(
|
||||
media_io_kwargs={"video": {
|
||||
"num_frames": num_frames,
|
||||
}},
|
||||
allowed_media_domains=[
|
||||
"www.bogotobogo.com",
|
||||
"github.com",
|
||||
])
|
||||
|
||||
video_sync, metadata_sync = connector.fetch_video(video_url)
|
||||
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
||||
assert np.array_equal(video_sync, video_async)
|
||||
assert metadata_sync == metadata_async
|
||||
|
||||
disallowed_url = "https://upload.wikimedia.org/wikipedia/commons/4/47/PNG_transparency_demonstration_1.png"
|
||||
with pytest.raises(ValueError):
|
||||
_, _ = connector.fetch_video(disallowed_url)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_, _ = await connector.fetch_video_async(disallowed_url)
|
||||
|
@ -26,5 +26,5 @@ class DummyPlatform(Platform):
|
||||
|
||||
def get_attn_backend_cls(self, backend_name, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla,
|
||||
has_sink):
|
||||
has_sink, use_sparse):
|
||||
return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501
|
||||
|
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -10,6 +11,7 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache(
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True) -> torch.Tensor:
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
|
||||
"""Create and prepopulate an MLA KV cache with context data.
|
||||
|
||||
Args:
|
||||
@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache(
|
||||
common_attn_metadata: Common attention metadata
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
or use sequential order
|
||||
kv_cache_dtype: Optional kv cache dtype string. When set to
|
||||
"fp8_ds_mla" the cache is populated using the
|
||||
fp8 DeepSeek MLA layout via concat_and_cache_mla.
|
||||
scale: Scaling factor forwarded to concat_and_cache_mla when the
|
||||
fp8 cache layout is requested.
|
||||
|
||||
Returns:
|
||||
MLA KV cache tensor
|
||||
@ -105,23 +114,61 @@ def create_and_prepopulate_kv_cache(
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
if not kv_c_contexts:
|
||||
raise ValueError("kv_c_contexts cannot be empty when using"
|
||||
" fp8_ds_mla cache dtype")
|
||||
kv_lora_rank = kv_c_contexts[0].shape[-1]
|
||||
rope_dim = k_pe_contexts[0].shape[-1]
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
scale_tensor = (scale
|
||||
if isinstance(scale, torch.Tensor) else torch.tensor(
|
||||
scale, dtype=torch.float32, device=device))
|
||||
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
|
||||
# Populate the cache with the context tokens
|
||||
# Start from block_id=1 since block_id=0 is considered the null block
|
||||
start_block_idx = 1
|
||||
for i in range(batch_size):
|
||||
kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
|
||||
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
context_len = kv_c_context.shape[0]
|
||||
if context_len == 0:
|
||||
start_block_idx += cdiv(int(seq_lens[i]), block_size)
|
||||
continue
|
||||
|
||||
start = start_block_idx * block_size
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
slots = torch.arange(context_len, device=device,
|
||||
dtype=torch.long) + start
|
||||
ops.concat_and_cache_mla(
|
||||
kv_c_context,
|
||||
k_pe_context.squeeze(1),
|
||||
kv_cache,
|
||||
slots,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale_tensor,
|
||||
)
|
||||
else:
|
||||
kv_context = torch.cat(
|
||||
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
# Stay block aligned and allocate enough blocks for the new tokens
|
||||
start_block_idx += cdiv(int(seq_lens[i]), block_size)
|
||||
|
448
tests/v1/attention/test_sparse_mla_backends.py
Normal file
@ -0,0 +1,448 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for the FlashMLA sparse backend utilities."""
|
||||
|
||||
import math
|
||||
from types import MethodType, SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_mla_backends import (
|
||||
BATCH_SPECS, BatchSpec, MockAttentionLayer,
|
||||
create_and_prepopulate_kv_cache)
|
||||
from tests.v1.attention.utils import (create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
|
||||
FlashMLASparseImpl, FlashMLASparseMetadata)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
name: BATCH_SPECS[name]
|
||||
for name in [
|
||||
"mixed_small",
|
||||
"mixed_medium",
|
||||
"small_prefill",
|
||||
"medium_prefill",
|
||||
"single_prefill",
|
||||
]
|
||||
}
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
|
||||
query_lens=[256] * 2)
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
|
||||
seq_lens=[256] * 2, query_lens=[256] * 2)
|
||||
|
||||
|
||||
def _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
|
||||
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
|
||||
|
||||
# The first kv_lora_rank bytes store FP8 latent values with one scale per
|
||||
# 128 element tile written as float32 right after the latent payload.
|
||||
scales = cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
latent = torch.empty(kv_lora_rank,
|
||||
dtype=torch.float16,
|
||||
device=cache_slice.device)
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = tile_start + 128
|
||||
ops.convert_fp8(latent[tile_start:tile_end],
|
||||
cache_slice[tile_start:tile_end],
|
||||
float(scales[tile_idx].item()),
|
||||
kv_dtype="fp8")
|
||||
latent = latent.to(dtype)
|
||||
|
||||
rope_offset = kv_lora_rank // 2 + 8
|
||||
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
|
||||
return latent, rope_vals.clone()
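# Illustrative sketch, not part of the diff: the byte layout implied by the
# dequantization above for a single fp8_ds_mla cache entry. The helper name
# is hypothetical.
def _fp8_ds_mla_entry_layout(kv_lora_rank: int, rope_dim: int) -> dict:
    # [0, kv_lora_rank)        : fp8 latent values, one byte per element
    # [kv_lora_rank, +16)      : four float32 scales, one per 128-element tile
    # [kv_lora_rank + 16, end) : rope values stored in the 2-byte cache dtype
    latent_bytes = kv_lora_rank
    scale_bytes = 4 * 4
    rope_bytes = 2 * rope_dim
    return {
        "latent": (0, latent_bytes),
        "scales": (latent_bytes, latent_bytes + scale_bytes),
        "rope": (latent_bytes + scale_bytes,
                 latent_bytes + scale_bytes + rope_bytes),
        "entry_size": latent_bytes + scale_bytes + rope_bytes,
    }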
|
||||
|
||||
|
||||
def _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
|
||||
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
|
||||
|
||||
if kv_c.numel() == 0:
|
||||
return kv_c.clone(), k_pe.clone()
|
||||
|
||||
kv_lora_rank = kv_c.shape[-1]
|
||||
rope_dim = k_pe.shape[-1]
|
||||
num_tokens = kv_c.shape[0]
|
||||
num_blocks = max(1, math.ceil(num_tokens / block_size))
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
|
||||
tmp_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=kv_c.device)
|
||||
slot_mapping = torch.arange(num_tokens,
|
||||
dtype=torch.long,
|
||||
device=kv_c.device)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c,
|
||||
k_pe,
|
||||
tmp_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale)
|
||||
|
||||
dequant_kv_c = torch.empty_like(kv_c)
|
||||
dequant_k_pe = torch.empty_like(k_pe)
|
||||
|
||||
for token_idx in range(num_tokens):
|
||||
slot = slot_mapping[token_idx].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
cache_slice = tmp_cache[block_idx, block_offset]
|
||||
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
|
||||
dequant_kv_c[token_idx] = latent
|
||||
dequant_k_pe[token_idx] = rope_vals
|
||||
|
||||
return dequant_kv_c, dequant_k_pe
|
||||
|
||||
|
||||
def test_sparse_backend_metadata_registration():
|
||||
backend = FlashMLASparseBackend
|
||||
|
||||
assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
|
||||
assert backend.get_metadata_cls() is FlashMLASparseMetadata
|
||||
assert backend.get_impl_cls() is FlashMLASparseImpl
|
||||
|
||||
dtype_list = backend.get_supported_dtypes()
|
||||
assert torch.bfloat16 in dtype_list
|
||||
|
||||
shape = backend.get_kv_cache_shape(num_blocks=2,
|
||||
block_size=64,
|
||||
num_kv_heads=1,
|
||||
head_size=576)
|
||||
assert shape == (2, 64, 576)
|
||||
|
||||
|
||||
def test_sparse_decode_metadata_filters_prefill_indices():
|
||||
prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
|
||||
metadata = FlashMLASparseDecodeAndContextMetadata(
|
||||
scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
|
||||
num_splits=torch.tensor([1, 1], dtype=torch.int32),
|
||||
cache_lens=torch.tensor([10, 12], dtype=torch.int32),
|
||||
prefill_context_lengths=prefill_context_lengths,
|
||||
)
|
||||
|
||||
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
|
||||
|
||||
context_indices, new_token_indices = metadata.filter_prefill_indices(
|
||||
indices)
|
||||
|
||||
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
|
||||
dtype=torch.int32)
|
||||
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
|
||||
dtype=torch.int32)
|
||||
|
||||
assert torch.equal(context_indices, expected_context)
|
||||
assert torch.equal(new_token_indices, expected_new_tokens)
|
||||
|
||||
|
||||
def test_sparse_impl_zero_fills_when_metadata_missing():
|
||||
impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
|
||||
dummy_layer = object()
|
||||
q = torch.zeros((2, 1, 3))
|
||||
k_c = torch.zeros((2, 3))
|
||||
k_pe = torch.zeros((2, 1, 1))
|
||||
kv_cache = torch.zeros((1, 1, 1))
|
||||
output = torch.ones((2, 4))
|
||||
|
||||
result = FlashMLASparseImpl.forward(impl,
|
||||
dummy_layer,
|
||||
q,
|
||||
k_c,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata=None,
|
||||
output=output)
|
||||
|
||||
assert result is output
|
||||
assert torch.all(result == 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
kv_cache_dtype):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
|
||||
|
||||
# Model hyper-parameters (kept intentionally small for the unit test)
|
||||
num_heads = 128
|
||||
kv_lora_rank = 512
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
v_head_dim = 128
|
||||
head_size = kv_lora_rank + qk_rope_head_dim
|
||||
topk_tokens = 2048
|
||||
|
||||
max_seqlen = max(batch_spec.seq_lens)
|
||||
total_cache_tokens = sum(batch_spec.seq_lens)
|
||||
block_size = 64
|
||||
|
||||
vllm_config = create_vllm_config(
|
||||
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
max_model_len=max_seqlen,
|
||||
num_gpu_blocks=max(2048,
|
||||
cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size)
|
||||
model_config = vllm_config.model_config
|
||||
model_config.hf_config = SimpleNamespace(
|
||||
attn_module_list_cfg=[{
|
||||
"topk_tokens": topk_tokens
|
||||
}])
|
||||
model_config.hf_text_config = SimpleNamespace(
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
model_type="deepseek_v2",
|
||||
)
|
||||
model_config.dtype = dtype
|
||||
model_config.get_num_attention_heads = MethodType(
|
||||
lambda self, parallel_config: num_heads, model_config)
|
||||
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
|
||||
model_config)
|
||||
model_config.get_head_size = MethodType(lambda self: head_size,
|
||||
model_config)
|
||||
model_config.get_sliding_window = MethodType(lambda self: None,
|
||||
model_config)
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = 1.0 / math.sqrt(head_size)
|
||||
|
||||
# Shared MLA projection weights to keep reference and backend in sync
|
||||
W_UK = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UV = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# Build synthetic decode-only workload
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
|
||||
all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
reference_outputs = []
|
||||
|
||||
kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(batch_spec.batch_size):
|
||||
s_len = seq_lens[i]
|
||||
q_len = query_lens[i]
|
||||
ctx_len = s_len - q_len
|
||||
|
||||
q_c = torch.rand(q_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe_full = torch.rand(s_len,
|
||||
1,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
k_pe_full.squeeze(1),
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
|
||||
q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
|
||||
|
||||
k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
|
||||
k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)
|
||||
|
||||
attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
|
||||
causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
|
||||
attn_mask[:, ctx_len:] = causal_mask
|
||||
|
||||
q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
|
||||
k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
reference_outputs.append(sdpa_out.flatten(start_dim=-2))
|
||||
|
||||
all_q_vllm.append(q_c)
|
||||
all_kv_c_vllm.append(kv_c_full[ctx_len:])
|
||||
all_k_pe_vllm.append(k_pe_full[ctx_len:])
|
||||
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
|
||||
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
|
||||
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
|
||||
sdpa_reference = torch.cat(reference_outputs, dim=0)
|
||||
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec,
|
||||
vllm_config.cache_config.block_size,
|
||||
device,
|
||||
arange_block_indices=True)
|
||||
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
k_pe_contexts=k_pe_contexts,
|
||||
block_size=vllm_config.cache_config.block_size,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=False,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
scale=kv_cache_scale,
|
||||
)
|
||||
|
||||
builder_cls = FlashMLASparseBackend.get_builder_cls()
|
||||
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
|
||||
metadata = builder.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
|
||||
dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
|
||||
starts[:-1], seg_lengths)
|
||||
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
|
||||
prefix_lengths = seq_lengths - seg_lengths
|
||||
positions += np.repeat(prefix_lengths, seg_lengths)
|
||||
|
||||
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
|
||||
topk = metadata.topk_tokens
|
||||
debug_indices = torch.arange(topk, device=device,
|
||||
dtype=torch.int32).unsqueeze(0)
|
||||
token_positions = pos_gpu.unsqueeze(1)
|
||||
causal_mask = (debug_indices <= token_positions)
|
||||
debug_indices = torch.where(causal_mask, debug_indices,
|
||||
torch.full_like(debug_indices, -1))
|
||||
|
||||
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
|
||||
# buffer, so emulate that contract with a simple namespace mock.
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
|
||||
-1).clone()
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
|
||||
|
||||
ok, reason = flashmla.is_flashmla_supported()
|
||||
if not ok:
|
||||
pytest.skip(reason)
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
|
||||
|
||||
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
|
||||
output_size=num_heads *
|
||||
(qk_nope_head_dim + v_head_dim),
|
||||
bias=False).to(device=device,
|
||||
dtype=dtype)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(metadata.num_actual_tokens,
|
||||
num_heads * v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
backend_output = impl.forward(layer,
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
metadata,
|
||||
output=out_buffer)
|
||||
|
||||
assert backend_output.shape == sdpa_reference.shape
|
||||
assert backend_output.dtype == sdpa_reference.dtype
|
||||
assert torch.isfinite(backend_output).all()
|
||||
|
||||
torch.testing.assert_close(backend_output,
|
||||
sdpa_reference,
|
||||
rtol=0.5,
|
||||
atol=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,max_buf,start,expected",
|
||||
[
|
||||
# Basic split: totals per chunk ≤ max_buf
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Non-zero start index
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would
|
||||
# overflow
|
||||
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||
# All requests fit in a single chunk
|
||||
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||
# Large buffer with non-zero start
|
||||
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||
],
|
||||
)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||
assert out == expected
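# Illustrative sketch, not part of the diff: a greedy reference that matches
# the expected chunks above. The real split_prefill_chunks lives in
# vllm.v1.attention.backends.mla.indexer; split_prefill_chunks_ref is a
# hypothetical name used only for comparison.
def split_prefill_chunks_ref(seq_lens: torch.Tensor, max_buf: int,
                             start: int) -> list[tuple[int, int]]:
    chunks, i, n = [], start, seq_lens.numel()
    while i < n:
        j, total = i, 0
        # Grow the chunk while the next request still fits in the buffer.
        while j < n and total + int(seq_lens[j]) <= max_buf:
            total += int(seq_lens[j])
            j += 1
        j = max(j, i + 1)  # always make progress, even if one request overflows
        chunks.append((i, j))
        i = j
    return chunks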
|
@ -168,7 +168,6 @@ def create_standard_kv_cache_spec(
|
||||
vllm_config.parallel_config),
|
||||
head_size=vllm_config.model_config.get_head_size(),
|
||||
dtype=vllm_config.model_config.dtype,
|
||||
use_mla=vllm_config.model_config.use_mla,
|
||||
sliding_window=vllm_config.model_config.get_sliding_window(),
|
||||
)
|
||||
|
||||
|
@ -24,7 +24,8 @@ from vllm.v1.core.kv_cache_utils import (
|
||||
make_block_hash_with_group_id)
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
KVCacheTensor, SlidingWindowSpec,
|
||||
KVCacheTensor, MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
UniformTypeKVCacheSpecs)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=None):
|
||||
return FullAttentionSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16,
|
||||
num_kv_heads=2,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
sliding_window=1):
|
||||
return SlidingWindowSpec(block_size=block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
use_mla=use_mla,
|
||||
sliding_window=sliding_window)
|
||||
|
||||
|
||||
@ -894,7 +891,6 @@ def test_merge_kv_cache_spec():
|
||||
num_kv_heads=full_spec.num_kv_heads,
|
||||
head_size=full_spec.head_size,
|
||||
dtype=full_spec.dtype,
|
||||
use_mla=full_spec.use_mla,
|
||||
sliding_window=1,
|
||||
),
|
||||
]
|
||||
@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len,
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
# Estimate the maximum model length, 16384 model_len need 8GB
|
||||
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
|
||||
@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
sliding_window_spec = SlidingWindowSpec(
|
||||
@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config():
|
||||
num_kv_heads=32,
|
||||
head_size=128,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
sliding_window=1024,
|
||||
)
|
||||
|
||||
@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config():
|
||||
KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec())
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def new_mla_spec(cache_dtype_str=None):
|
||||
return MLAAttentionSpec(block_size=16,
|
||||
num_kv_heads=16,
|
||||
head_size=64,
|
||||
dtype=torch.float32,
|
||||
cache_dtype_str=cache_dtype_str)
|
||||
|
||||
|
||||
def test_merge_mla_spec():
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(),
|
||||
new_mla_spec(),
|
||||
]
|
||||
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
|
||||
assert mla_spec == new_mla_spec()
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
]
|
||||
mla_spec = kv_cache_specs[0].merge(kv_cache_specs)
|
||||
assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla")
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_mla_spec(cache_dtype_str=None),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
kv_cache_specs = [
|
||||
new_kv_cache_spec(),
|
||||
new_mla_spec(),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
||||
kv_cache_specs = [
|
||||
new_mla_spec(cache_dtype_str="fp8_ds_mla"),
|
||||
new_kv_cache_spec(),
|
||||
]
|
||||
with pytest.raises(AssertionError):
|
||||
kv_cache_specs[0].merge(kv_cache_specs)
|
||||
|
@ -76,7 +76,7 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
)
|
||||
],
|
||||
)
|
||||
@ -90,7 +90,7 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
["layer1"],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
["layer2"],
|
||||
@ -98,7 +98,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
KVCacheGroupSpec(
|
||||
@ -107,7 +106,6 @@ def make_kv_cache_config_hybrid_model(block_size: int,
|
||||
1,
|
||||
1,
|
||||
torch.float32,
|
||||
False,
|
||||
sliding_window=2 * block_size),
|
||||
),
|
||||
],
|
||||
@ -1338,7 +1336,6 @@ def test_eagle_with_sliding_window():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=block_size,
|
||||
use_mla=False,
|
||||
)
|
||||
manager = KVCacheManager(
|
||||
KVCacheConfig(
|
||||
|
@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -100,7 +99,6 @@ def test_sliding_window_possible_cached_prefix():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -165,7 +163,6 @@ def test_chunked_local_attention_remove_skipped_blocks():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
@ -217,7 +214,6 @@ def test_sliding_window_remove_skipped_blocks():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4,
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
|
||||
@ -285,7 +281,6 @@ def test_get_num_blocks_to_allocate():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
sliding_window=4, # Placeholder value, not related to test result
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
@ -308,7 +303,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate():
|
||||
head_size=1,
|
||||
dtype=torch.float32,
|
||||
attention_chunk_size=4, # Placeholder value, not related to test result
|
||||
use_mla=False,
|
||||
)
|
||||
|
||||
block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
|
||||
|
@ -15,6 +15,8 @@ from vllm.assets.image import VLM_IMAGES_DIR
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MTP_SIMILARITY_RATE = 0.8
|
||||
|
||||
|
||||
def get_test_prompts(mm_enabled: bool):
|
||||
prompt_types = ["repeat", "sentence"]
|
||||
@ -222,3 +224,66 @@ def test_eagle_correctness(
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
|
||||
(("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False),
|
||||
(("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False),
|
||||
],
|
||||
ids=["mimo", "deepseek"])
|
||||
def test_mtp_correctness(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
sampling_config: SamplingParams,
|
||||
model_setup: tuple[str, str, int],
|
||||
mm_enabled: bool,
|
||||
):
|
||||
# Generate test prompts inside the function instead of using a fixture
|
||||
test_prompts = get_test_prompts(mm_enabled)
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using MTP speculative decoding.
|
||||
model_setup: (method, model_name, tp_size)
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv("VLLM_MLA_DISABLE", "1")
|
||||
|
||||
method, model_name, tp_size = model_setup
|
||||
|
||||
ref_llm = LLM(model=model_name,
|
||||
max_model_len=2048,
|
||||
tensor_parallel_size=tp_size,
|
||||
trust_remote_code=True)
|
||||
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
spec_llm = LLM(
|
||||
model=model_name,
|
||||
trust_remote_code=True,
|
||||
tensor_parallel_size=tp_size,
|
||||
speculative_config={
|
||||
"method": method,
|
||||
"num_speculative_tokens": 1,
|
||||
"max_model_len": 2048,
|
||||
},
|
||||
max_model_len=2048,
|
||||
)
|
||||
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
if ref_output.outputs[0].text == spec_output.outputs[0].text:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output.outputs[0].text}")
|
||||
print(f"spec_output: {spec_output.outputs[0].text}")
|
||||
|
||||
# Heuristic: expect at least 80% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(MTP_SIMILARITY_RATE * len(ref_outputs))
|
||||
del spec_llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
@ -836,8 +836,7 @@ def test_engine_core_proc_instantiation_cuda_empty(
|
||||
mock_spec = FullAttentionSpec(block_size=16,
|
||||
num_kv_heads=1,
|
||||
head_size=64,
|
||||
dtype=torch.float16,
|
||||
use_mla=False)
|
||||
dtype=torch.float16)
|
||||
|
||||
mock_executor.get_kv_cache_specs.return_value = [{
|
||||
"default": mock_spec
|
||||
|
@ -255,8 +255,9 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
time.sleep(self._hand_shake_latency)
|
||||
# These should've been done in register_kv_caches(), called by
|
||||
# gpu_model_runner. Here we just hardcode some dummy values.
|
||||
self.slot_size_bytes = 4096
|
||||
self.block_len = self.slot_size_bytes * self.block_size
|
||||
slot_size_bytes = 4096
|
||||
self.slot_size_per_layer = [slot_size_bytes]
|
||||
self.block_len_per_layer = [slot_size_bytes * self.block_size]
|
||||
self.num_blocks = 1
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
|
||||
@ -268,7 +269,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_len=self.block_len,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
# `self.kv_cache_layout` is only forced to HND when vllm engine
|
||||
# is started. We mock HND here.
|
||||
@ -485,8 +486,8 @@ class TestNixlHandshake:
|
||||
worker = connector.connector_worker
|
||||
|
||||
# Minimal local registration params used by add_remote_agent
|
||||
worker.slot_size_bytes = 4096
|
||||
worker.block_len = worker.slot_size_bytes * worker.block_size
|
||||
worker.slot_size_per_layer = [4096]
|
||||
worker.block_len_per_layer = [4096 * worker.block_size]
|
||||
worker.num_blocks = 1
|
||||
worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
|
||||
|
||||
@ -498,7 +499,7 @@ class TestNixlHandshake:
|
||||
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
|
||||
kv_caches_base_addr=[0],
|
||||
num_blocks=1,
|
||||
block_len=worker.block_len,
|
||||
block_lens=worker.block_len_per_layer,
|
||||
attn_backend_name=worker.backend_name,
|
||||
kv_cache_layout=mismatched_layout,
|
||||
)
|
||||
|
@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
"target_attn_1": mock.MagicMock(),
|
||||
"target_attn_2": mock.MagicMock()
|
||||
}
|
||||
target_indx_layers: dict[str, mock.MagicMock] = {}
|
||||
# Draft model has one extra attention layer compared to target model
|
||||
all_attn_layers = {
|
||||
**target_attn_layers, "draft_extra_attn": mock.MagicMock()
|
||||
}
|
||||
|
||||
all_indx_layers: dict[str, mock.MagicMock] = {}
|
||||
|
||||
# Make mock_get_layers return different values for each call
|
||||
mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
|
||||
mock_get_layers.side_effect = [
|
||||
target_attn_layers, target_indx_layers, all_attn_layers,
|
||||
all_indx_layers
|
||||
]
|
||||
|
||||
# Setup mock for pp group to return the appropriate value for world size
|
||||
mock_pp_group = mock.MagicMock()
|
||||
@ -658,6 +664,9 @@ def test_propose_tree(spec_token_tree):
|
||||
# Mock runner for attention metadata building.
|
||||
proposer.runner = mock.MagicMock()
|
||||
proposer.runner.attn_groups.append([mock.MagicMock()])
|
||||
proposer.runner.attn_groups[0][0].metadata_builders = [
|
||||
attn_metadata_builder
|
||||
]
|
||||
proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
|
||||
attn_metadata_builder
|
||||
proposer._get_attention_metadata_builder = mock.MagicMock(
|
||||
|
201
tests/v1/spec_decode/test_mtp.py
Normal file
@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest import mock

import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      get_attention_backend)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         VllmConfig)
from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer

mimo_7b_dir = "XiaomiMiMo/MiMo-7B-Base"


def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
    """Create an MTP proposer with unified model configuration."""
    model_config = ModelConfig(model=mimo_7b_dir,
                               runner="generate",
                               max_model_len=100,
                               trust_remote_code=True)

    speculative_config = SpeculativeConfig(
        target_model_config=model_config,
        target_parallel_config=ParallelConfig(),
        model=mimo_7b_dir,
        method="mtp",
        num_speculative_tokens=num_speculative_tokens,
    )

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(),
        speculative_config=speculative_config,
        device_config=DeviceConfig(device=current_platform.device_type),
        parallel_config=ParallelConfig(),
        load_config=LoadConfig(),
        scheduler_config=SchedulerConfig())

    return EagleProposer(vllm_config=vllm_config,
                         device=current_platform.device_type)


@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
                                mock_get_pp_group):
    """Test MTP-specific model loading with unified model approach."""

    # Setup mocks
    mock_model = mock.MagicMock()
    mock_model.model.embed_tokens.weight.shape = (131072, 4096)
    mock_get_model.return_value = mock_model

    target_attn_layers = {"target_attn_1": mock.MagicMock()}
    all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
    target_indexer_layers: dict = {}
    all_indexer_layers: dict = {}

    mock_get_layers.side_effect = [
        target_attn_layers, target_indexer_layers, all_attn_layers,
        all_indexer_layers
    ]

    mock_pp_group = mock.MagicMock()
    mock_pp_group.world_size = 1
    mock_get_pp_group.return_value = mock_pp_group

    # Create target model
    class _TargetModelStub(LlamaForCausalLM):
        model: mock.MagicMock
        lm_head: mock.MagicMock

    target_model = mock.create_autospec(_TargetModelStub, instance=True)
    target_model.model = mock.MagicMock()
    target_model.model.embed_tokens.weight.shape = (131072, 4096)
    target_model.lm_head = mock.MagicMock()

    # Create MTP proposer
    proposer = _create_mtp_proposer(num_speculative_tokens=4)
    proposer.load_model(target_model)

    # Verify MTP-specific behavior:
    # Model is loaded
    mock_get_model.assert_called_once()
    # MTP shares lm_head with target model
    assert proposer.model.lm_head == target_model.lm_head
    # MTP shares embed_tokens with target model
    assert proposer.model.model.embed_tokens == target_model.model.embed_tokens


@pytest.mark.parametrize("num_speculative_tokens", [1])
def test_mtp_propose(num_speculative_tokens, monkeypatch):
    """Test that MTP's forward method returns hidden states directly"""

    device = torch.device(current_platform.device_type)
    batch_size = 2
    seq_lens = [5, 3]
    total_tokens = sum(seq_lens)
    vocab_size = 100

    proposer = _create_mtp_proposer(num_speculative_tokens)
    hidden_size = proposer.hidden_size

    # Mock the MTP model to verify it returns hidden states directly
    model_mock = mock.MagicMock()

    # MTP returns hidden states directly
    if num_speculative_tokens == 1:
        model_mock.return_value = torch.zeros(total_tokens,
                                              hidden_size,
                                              device=device)
    else:
        # Multiple forward passes for multi-token speculation
        forward_returns = []
        for i in range(num_speculative_tokens):
            if i == 0:
                h_states = torch.zeros(total_tokens,
                                       hidden_size,
                                       device=device)
            else:
                h_states = torch.zeros(batch_size, hidden_size, device=device)
            forward_returns.append(h_states)
        model_mock.side_effect = forward_returns

    # Mock compute_logits
    def create_deterministic_logits(batch_size, vocab_size, token_offset):
        logits = torch.full((batch_size, vocab_size), -100.0, device=device)
        logits[:, token_offset] = 100.0
        return logits

    if num_speculative_tokens == 1:
        model_mock.compute_logits.return_value = create_deterministic_logits(
            batch_size, vocab_size, 42)
    else:
        logits_returns = [
            create_deterministic_logits(batch_size, vocab_size, 42 + i)
            for i in range(num_speculative_tokens)
        ]
        model_mock.compute_logits.side_effect = logits_returns

    proposer.model = model_mock
    proposer.attn_layer_names = ["layer.0"]

    # Prepare inputs
    batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=seq_lens)
    common_attn_metadata = create_common_attn_metadata(batch_spec,
                                                       block_size=16,
                                                       device=device)

    target_token_ids = torch.randint(0,
                                     vocab_size, (total_tokens, ),
                                     device=device)
    target_positions = torch.cat([
        torch.arange(seq_lens[0], device=device),
        torch.arange(seq_lens[1], device=device)
    ])
    target_hidden_states = torch.randn(total_tokens,
                                       hidden_size,
                                       device=device)
    next_token_ids = torch.randint(0,
                                   vocab_size, (batch_size, ),
                                   dtype=torch.int32,
                                   device=device)
    sampling_metadata = mock.MagicMock()

    # Setup attention metadata
    attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)

    attn_metadata_builder = attn_metadata_builder_cls(
        kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
        layer_names=proposer.attn_layer_names,
        vllm_config=proposer.vllm_config,
        device=device,
    )

    proposer.runner = mock.MagicMock()
    proposer.attn_metadata_builder = attn_metadata_builder

    # Run propose
    result = proposer.propose(target_token_ids=target_token_ids,
                              target_positions=target_positions,
                              target_hidden_states=target_hidden_states,
                              next_token_ids=next_token_ids,
                              last_token_indices=None,
                              common_attn_metadata=common_attn_metadata,
                              sampling_metadata=sampling_metadata)

    # Verify the model was called correctly
    assert model_mock.called
    # Verify output shape
    assert result.shape == (batch_size, num_speculative_tokens)

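The new test drives the MTP path through EagleProposer directly. For context, the same configuration can be reached through the public engine arguments; the sketch below is illustrative only, assuming the dict form of the speculative_config engine argument and reusing the MiMo checkpoint from the test (adjust the model and token count to your setup).

# Hedged sketch: enabling MTP speculative decoding end to end. Assumes vLLM
# accepts the dict form of `speculative_config` shown here and that the target
# checkpoint ships MTP weights; this is not taken from the diff above.
from vllm import LLM, SamplingParams

llm = LLM(
    model="XiaomiMiMo/MiMo-7B-Base",
    trust_remote_code=True,
    speculative_config={
        "method": "mtp",              # replaces the deprecated *_mtp method names
        "num_speculative_tokens": 1,  # MTP checkpoints typically ship one layer
    },
)
print(llm.generate(["The capital of France is"], SamplingParams(max_tokens=8)))
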
@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
|
||||
kv_cache_config = KVCacheConfig(
|
||||
|
@ -1678,6 +1678,15 @@ def cp_gather_cache(src_cache: torch.Tensor,
|
||||
cu_seq_lens, batch_size, seq_starts)
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache(k: torch.Tensor, kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
kv_cache_dtype: str) -> None:
|
||||
torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping,
|
||||
quant_block_size,
|
||||
kv_cache_dtype)
|
||||
|
||||
|
||||
def get_device_attribute(attribute: int, device: int) -> int:
|
||||
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
|
||||
|
||||
|
@ -70,6 +70,7 @@ class AttentionBackend(ABC):
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -95,6 +95,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sparse: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
@ -155,6 +156,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.use_sparse = use_sparse
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -187,7 +189,8 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
has_sink=self.has_sink)
|
||||
has_sink=self.has_sink,
|
||||
use_sparse=use_sparse)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
from typing import ClassVar, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig, QuantizationConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
||||
subclass_attention_backend)
|
||||
AttentionCGSupport, CommonAttentionMetadata,
|
||||
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||
|
||||
from ..layer import Attention
|
||||
|
||||
@ -28,6 +28,8 @@ def create_chunked_local_attention_backend(
|
||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||
|
||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
|
@ -138,3 +138,208 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
|
||||
out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
|
||||
out = cp_group.reduce_scatter(out, dim=1)
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _pack_seq_kernel(
|
||||
x_ptr, # [N, D]
|
||||
out_ptr, # [B, Lmax, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
N: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
PAD_VALUE: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# Compute start index and sequence length from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
|
||||
# compute input row indices for valid (b, t)
|
||||
in_row = in_start + off_t
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# Pointers
|
||||
# x_ptr: row-major [N, D]
|
||||
x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [B, Lmax, D]
|
||||
out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:,
|
||||
None] * D + off_d[None, :]
|
||||
|
||||
# Initialize with PAD (cast will occur as needed based on out_ptr dtype)
|
||||
d_mask = off_d[None, :] < D
|
||||
pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32)
|
||||
tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask)
|
||||
|
||||
# Load & write only where within seq_len
|
||||
x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def pack_seq_triton(x: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
pad_value: float = -float('inf'),
|
||||
block_t: int = 64,
|
||||
block_d: int = 64) -> torch.Tensor:
|
||||
"""
|
||||
Pack sequences of different lengths into a batched tensor.
|
||||
|
||||
Args:
|
||||
x: [N, ...] - input tensor where N is total number of tokens
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
pad_value: value to use for padding
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
packed: [B, Lmax, ...] - packed tensor
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (N, -1)
|
||||
original_shape = x.shape
|
||||
if len(original_shape) > 2:
|
||||
N = original_shape[0]
|
||||
x_reshaped = x.reshape(N, -1)
|
||||
D = x_reshaped.shape[1]
|
||||
else:
|
||||
N, D = x.shape
|
||||
x_reshaped = x
|
||||
|
||||
B = lengths.numel()
|
||||
Lmax = int(lengths.max().item())
|
||||
|
||||
# Starts are computed inside the kernel from lengths
|
||||
|
||||
out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_pack_seq_kernel[grid](x_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
N,
|
||||
D,
|
||||
Lmax,
|
||||
PAD_VALUE=float(pad_value),
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 2:
|
||||
output_shape = (B, Lmax) + original_shape[1:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _unpack_seq_triton_kernel(
|
||||
packed_ptr, # [B, Lmax, D]
|
||||
out_ptr, # [N, D]
|
||||
lengths_ptr, # *i32, [B]
|
||||
B: tl.constexpr,
|
||||
Lmax: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr, # timesteps per program
|
||||
BLOCK_D: tl.constexpr # features per program
|
||||
):
|
||||
pid_b = tl.program_id(0) # batch id
|
||||
pid_t = tl.program_id(1) # block over time dimension
|
||||
pid_d = tl.program_id(2) # block over feature dimension
|
||||
off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T]
|
||||
off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D]
|
||||
|
||||
# bounds: compute start from cumulative lengths
|
||||
in_start = 0
|
||||
for i in range(pid_b):
|
||||
in_start += tl.load(lengths_ptr + i)
|
||||
seq_len = tl.load(lengths_ptr + pid_b)
|
||||
|
||||
# valid time positions for this block
|
||||
t_mask = off_t < Lmax
|
||||
valid_row = (off_t < seq_len) & t_mask
|
||||
|
||||
# compute output row indices for valid (b, t)
|
||||
out_row = in_start + off_t
|
||||
|
||||
# Pointers
|
||||
# packed_ptr: row-major [B, Lmax, D]
|
||||
packed_row_ptr = packed_ptr + (pid_b * Lmax +
|
||||
off_t)[:, None] * D + off_d[None, :]
|
||||
|
||||
# out_ptr: row-major [N, D]
|
||||
out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :]
|
||||
|
||||
# Load from packed tensor and store to output
|
||||
d_mask = off_d[None, :] < D
|
||||
packed_vals = tl.load(packed_row_ptr, mask=valid_row[:, None] & d_mask)
|
||||
tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask)
|
||||
|
||||
|
||||
def unpack_seq_triton(packed_tensor: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
block_t: int = 64,
|
||||
block_d: int = 64) -> torch.Tensor:
|
||||
"""
|
||||
Unpack a packed decode query tensor back to the original format.
|
||||
Efficient Triton implementation.
|
||||
|
||||
Args:
|
||||
packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton
|
||||
lengths: [B] - sequence lengths for each batch
|
||||
block_t: block size for time dimension
|
||||
block_d: block size for feature dimension
|
||||
|
||||
Returns:
|
||||
unpacked_tensor: [N, ...] where N = sum(lengths)
|
||||
"""
|
||||
|
||||
# Handle multi-dimensional input by reshaping to (B, Lmax, -1)
|
||||
original_shape = packed_tensor.shape
|
||||
if len(original_shape) > 3:
|
||||
B, Lmax = original_shape[:2]
|
||||
packed_reshaped = packed_tensor.reshape(B, Lmax, -1)
|
||||
D = packed_reshaped.shape[2]
|
||||
else:
|
||||
B, Lmax, D = packed_tensor.shape
|
||||
packed_reshaped = packed_tensor
|
||||
|
||||
# Calculate total number of elements
|
||||
N = int(lengths.sum().item())
|
||||
|
||||
out = torch.empty((N, D),
|
||||
device=packed_tensor.device,
|
||||
dtype=packed_tensor.dtype)
|
||||
|
||||
grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d))
|
||||
_unpack_seq_triton_kernel[grid](packed_reshaped,
|
||||
out,
|
||||
lengths.int(),
|
||||
B,
|
||||
Lmax,
|
||||
D,
|
||||
BLOCK_T=block_t,
|
||||
BLOCK_D=block_d,
|
||||
num_warps=4,
|
||||
num_stages=2)
|
||||
|
||||
# Reshape output back to original dimensions (except first dimension)
|
||||
if len(original_shape) > 3:
|
||||
output_shape = (N, ) + original_shape[2:]
|
||||
out = out.reshape(output_shape)
|
||||
|
||||
return out
|
||||
|
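To make the contract of the two packing helpers above concrete, here is a minimal round-trip sketch. It assumes a CUDA device and that the helpers are importable from the module this hunk extends; the exact import path is an assumption, not something stated in the diff.

# Round-trip sketch for pack_seq_triton / unpack_seq_triton (see assumptions
# in the lead-in; the import path may differ in the actual tree).
import torch
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton

lengths = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
x = torch.randn(int(lengths.sum()), 4, 64, device="cuda", dtype=torch.bfloat16)

packed = pack_seq_triton(x, lengths)           # [2, 5, 4, 64], padded with -inf
restored = unpack_seq_triton(packed, lengths)  # [8, 4, 64]

assert restored.shape == x.shape
assert torch.equal(restored, x)  # valid positions are copied verbatim
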
@ -19,6 +19,15 @@ if current_platform.is_cuda():
|
||||
else:
|
||||
_flashmla_C_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
try:
|
||||
import vllm._flashmla_extension_C # noqa: F401
|
||||
_flashmla_extension_C_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
else:
|
||||
_flashmla_extension_C_AVAILABLE = False
|
||||
|
||||
|
||||
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
@ -37,24 +46,34 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
|
||||
|
||||
|
||||
def get_mla_metadata(
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_heads_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
cache_seqlens: torch.Tensor,
|
||||
num_q_tokens_per_head_k: int,
|
||||
num_heads_k: int,
|
||||
num_heads_q: Optional[int] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
cache_seqlens: (batch_size), dtype torch.int32.
|
||||
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
|
||||
num_heads_k: num_heads_k.
|
||||
- cache_seqlens: (batch_size), dtype torch.int32.
|
||||
- num_q_tokens_per_head_k:
|
||||
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
|
||||
- num_heads_k: The number of k heads.
|
||||
- num_heads_q:
|
||||
The number of q heads.
|
||||
This argument is optional when sparse attention is not enabled
|
||||
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
|
||||
- topk: If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array
|
||||
passed to `flash_mla_with_kvcache_sm90` will be attended to.
|
||||
|
||||
Return:
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
dtype torch.int32.
|
||||
num_splits: (batch_size + 1), dtype torch.int32.
|
||||
Returns:
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens,
|
||||
num_heads_per_head_k,
|
||||
num_heads_k)
|
||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||
cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q,
|
||||
is_fp8_kvcache, topk)
|
||||
|
||||
|
||||
def flash_mla_with_kvcache(
|
||||
@ -69,45 +88,95 @@ def flash_mla_with_kvcache(
|
||||
causal: bool = False,
|
||||
descale_q: Optional[torch.Tensor] = None,
|
||||
descale_k: Optional[torch.Tensor] = None,
|
||||
is_fp8_kvcache: bool = False,
|
||||
indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
cache_seqlens: (batch_size), torch.int32.
|
||||
head_dim_v: Head_dim of v.
|
||||
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
|
||||
torch.int32, return by get_mla_metadata.
|
||||
num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
|
||||
softmax_scale: float. The scaling of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
causal: bool. Whether to apply causal attention mask.
|
||||
descale_q: (batch_size), torch.float32. Descaling factors for Q.
|
||||
descale_k: (batch_size), torch.float32. Descaling factors for K.
|
||||
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
|
||||
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
|
||||
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
|
||||
- cache_seqlens: (batch_size), torch.int32.
|
||||
- head_dim_v: Head dimension of v.
|
||||
- tile_scheduler_metadata:
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
|
||||
returned by get_mla_metadata.
|
||||
- num_splits:
|
||||
(batch_size + 1), torch.int32, returned by get_mla_metadata.
|
||||
- softmax_scale: float.
|
||||
The scale of QK^T before applying softmax.
|
||||
Default to 1 / sqrt(head_dim).
|
||||
- causal: bool. Whether to apply causal attention mask.
|
||||
- descale_q: (batch_size),
|
||||
torch.float32. Descaling factors for Q, used for fp8 quantization.
|
||||
- descale_k: (batch_size),
|
||||
torch.float32. Descaling factors for K, used for fp8 quantization.
|
||||
- is_fp8_kvcache: bool.
|
||||
Whether the k_cache and v_cache are in fp8 format.
|
||||
For the format of FP8 KV cache, please refer to README.md
|
||||
- indices: (batch_size, seq_len_q, topk), torch.int32.
|
||||
If not None, sparse attention will be enabled,
|
||||
and only tokens in the `indices` array will be attended to.
|
||||
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
|
||||
For details about how to set up `indices`, please refer to README.md.
|
||||
|
||||
Return:
|
||||
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
Returns:
|
||||
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
|
||||
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
|
||||
"""
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1]**(-0.5)
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q,
|
||||
k_cache,
|
||||
head_dim_v,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
softmax_scale,
|
||||
causal,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
descale_q,
|
||||
descale_k,
|
||||
)
|
||||
if indices is not None:
|
||||
# NOTE (zyongye): sparse attention is also causal
|
||||
# since it only attends to the tokens before,
|
||||
# but here `causal` should not be specified
|
||||
assert not causal, \
|
||||
"causal must be `false` if sparse attention is enabled."
|
||||
assert (descale_q is None) == (
|
||||
descale_k is None
|
||||
), "descale_q and descale_k should be both None or both not None"
|
||||
|
||||
# Note(hc): need revisit when we support DCP with decode query_len > 1.
|
||||
return out.squeeze(1), softmax_lse.squeeze(-1)
|
||||
if indices is None and q.element_size() == 1:
|
||||
out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
|
||||
else:
|
||||
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
|
||||
q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
|
||||
causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache,
|
||||
indices)
|
||||
return out, softmax_lse
|
||||
|
||||
|
||||
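A decode-shaped invocation of the two entry points above could look like the following. This is a sketch under stated assumptions only: an SM90 GPU with the FlashMLA extension built, a bfloat16 KV cache, the 576-dim QK / 512-dim V layout described in the docstrings, and an assumed import path. Tensor contents are random placeholders; only the shapes matter.

# Sketch only: shapes follow the docstrings above; requires the compiled
# _flashmla_C extension on a Hopper GPU. Import path is assumed.
import torch
from vllm.attention.ops.flashmla import get_mla_metadata, flash_mla_with_kvcache

batch, s_q, h_q, h_kv, d_qk, d_v, page = 4, 1, 128, 1, 576, 512, 64
cache_seqlens = torch.full((batch, ), 1024, dtype=torch.int32, device="cuda")

tile_md, num_splits = get_mla_metadata(
    cache_seqlens,
    num_q_tokens_per_head_k=s_q * h_q // h_kv,
    num_heads_k=h_kv)

q = torch.randn(batch, s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
k_cache = torch.randn(2048, page, h_kv, d_qk, dtype=torch.bfloat16, device="cuda")
block_table = torch.arange(batch * 16, dtype=torch.int32,
                           device="cuda").reshape(batch, 16)

out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens,
                                  head_dim_v=d_v,
                                  tile_scheduler_metadata=tile_md,
                                  num_splits=num_splits,
                                  causal=True)
# out: [4, 1, 128, 512], lse: [4, 128, 1]
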
def flash_mla_sparse_prefill(
|
||||
q: torch.Tensor,
|
||||
kv: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
sm_scale: float,
|
||||
d_v: int = 512,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sparse attention prefill kernel
|
||||
|
||||
Args:
|
||||
- q: [s_q, h_q, d_qk], bfloat16
|
||||
- kv: [s_kv, h_kv, d_qk], bfloat16
|
||||
- indices: [s_q, h_kv, topk], int32.
|
||||
Invalid indices should be set to -1 or numbers >= s_kv
|
||||
- sm_scale: float
|
||||
- d_v: The dimension of value vectors. Can only be 512
|
||||
|
||||
Returns:
|
||||
- (output, max_logits, lse)
|
||||
About the definition of output,
|
||||
max_logits and lse, please refer to README.md
|
||||
- output: [s_q, h_q, d_v], bfloat16
|
||||
- max_logits: [s_q, h_q], float
|
||||
- lse: [s_q, h_q], float, 2-based log-sum-exp
|
||||
"""
|
||||
results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices,
|
||||
sm_scale, d_v)
|
||||
return results
|
||||
|
||||
|
||||
#
|
||||
|
@ -50,6 +50,7 @@ class PagedAttention:
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
cache_dtype_str: str = "auto",
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size * num_kv_heads * head_size)
|
||||
|
||||
|
@ -144,6 +144,7 @@ def get_attn_backend(
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
@ -158,6 +159,7 @@ def get_attn_backend(
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
has_sink=has_sink,
|
||||
use_sparse=use_sparse,
|
||||
)
|
||||
|
||||
|
||||
@ -170,6 +172,7 @@ def _cached_get_attn_backend(
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
has_sink: bool = False,
|
||||
use_sparse: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
@ -203,7 +206,7 @@ def _cached_get_attn_backend(
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
|
||||
use_mla, has_sink)
|
||||
use_mla, has_sink, use_sparse)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}")
|
||||
|
@ -22,7 +22,8 @@ else:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BlockSize = Literal[1, 8, 16, 32, 64, 128]
|
||||
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
|
||||
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2",
|
||||
"fp8_inc"]
|
||||
MambaDType = Literal["auto", "float32"]
|
||||
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
|
||||
|
||||
@ -52,7 +53,11 @@ class CacheConfig:
|
||||
cache_dtype: CacheDType = "auto"
|
||||
"""Data type for kv cache storage. If "auto", will use model data type.
|
||||
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
|
||||
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc)."""
|
||||
fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).
Some models (namely DeepSeekV3.2) default to fp8; set this to bfloat16 to use
bfloat16 instead. This option is invalid for models that do not default to fp8.
"""
|
||||
is_attention_free: bool = False
|
||||
"""Whether the model is attention-free. This is primarily set in
|
||||
`ModelConfig` and that value should be manually duplicated here."""
|
||||
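The new "bfloat16" choice is aimed at models whose checkpoints default the KV cache to fp8 (DeepSeek-V3.2 in this diff). A minimal way to exercise it, assuming the standard engine-argument plumbing and an illustrative checkpoint name, is:

# Hedged sketch: overriding the KV cache dtype through the engine arguments.
# The model name is illustrative; the kwarg maps onto CacheConfig.cache_dtype.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # assumed checkpoint name
    kv_cache_dtype="bfloat16",              # newly valid alongside the fp8 variants
    trust_remote_code=True,
)

The equivalent server flag is --kv-cache-dtype bfloat16.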
@ -171,11 +176,12 @@ class CacheConfig:
|
||||
if self.cache_dtype == "auto":
|
||||
pass
|
||||
elif self.cache_dtype in get_args(CacheDType):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor.")
|
||||
if self.cache_dtype.startswith("fp8"):
|
||||
logger.info(
|
||||
"Using fp8 data type to store kv cache. It reduces the GPU "
|
||||
"memory footprint and boosts the performance. "
|
||||
"Meanwhile, it may cause accuracy drop without a proper "
|
||||
"scaling factor.")
|
||||
else:
|
||||
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
|
||||
|
||||
|
@ -360,6 +360,7 @@ class CompilationConfig:
|
||||
"vllm.linear_attention",
|
||||
"vllm.plamo2_mamba_mixer",
|
||||
"vllm.gdn_attention",
|
||||
"vllm.sparse_attn_indexer",
|
||||
]
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
|
@ -137,6 +137,9 @@ class ModelConfig:
|
||||
"""Allowing API requests to read local images or videos from directories
|
||||
specified by the server file system. This is a security risk. Should only
|
||||
be enabled in trusted environments."""
|
||||
allowed_media_domains: Optional[list[str]] = None
|
||||
"""If set, only media URLs that belong to this domain can be used for
|
||||
multi-modal inputs. """
|
||||
revision: Optional[str] = None
|
||||
"""The specific model version to use. It can be a branch name, a tag name,
|
||||
or a commit id. If unspecified, will use the default version."""
|
||||
@ -1074,14 +1077,14 @@ class ModelConfig:
|
||||
if not hasattr(self.hf_text_config, "model_type"):
|
||||
return False
|
||||
elif self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_v32', 'deepseek_mtp',
|
||||
'kimi_k2', 'longcat_flash'):
|
||||
return self.hf_text_config.kv_lora_rank is not None
|
||||
elif self.hf_text_config.model_type == 'eagle':
|
||||
# if the model is an EAGLE module, check for the
|
||||
# underlying architecture
|
||||
return self.hf_text_config.model.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3') \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
|
||||
and self.hf_text_config.kv_lora_rank is not None
|
||||
return False
|
||||
|
||||
|
@ -279,6 +279,24 @@ class ParallelConfig:
|
||||
assert last_exc is not None
|
||||
raise last_exc
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
@property
|
||||
def use_sequence_parallel_moe(self) -> bool:
|
||||
return (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("allgather_reducescatter", "naive",
|
||||
"deepep_high_throughput", "deepep_low_latency")
|
||||
and self.enable_expert_parallel
|
||||
and self.tensor_parallel_size > 1
|
||||
and self.data_parallel_size > 1)
|
||||
|
||||
@staticmethod
|
||||
def has_unfinished_dp(dp_group: ProcessGroup,
|
||||
has_unfinished: bool) -> bool:
|
||||
|
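The new use_sequence_parallel_moe property gates sequence-parallel MoE on the all2all backend and the parallel sizes. The sketch below evaluates the condition directly; it assumes ParallelConfig can be constructed standalone with these fields (in practice they arrive via EngineArgs) and that vllm.envs resolves the environment variable at access time.

# Minimal check of the property under the assumptions stated above.
import os

from vllm.config import ParallelConfig

# Must be one of the four backends listed in the property.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"

pc = ParallelConfig(tensor_parallel_size=2,
                    data_parallel_size=2,
                    enable_expert_parallel=True)
print(pc.use_sequence_parallel_moe)  # expected: True given the backend above
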
@ -32,14 +32,17 @@ logger = init_logger(__name__)
|
||||
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
|
||||
"mlp_speculator", "draft_model", "deepseek_mtp",
|
||||
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
|
||||
"longcat_flash_mtp"]
|
||||
"longcat_flash_mtp", "mtp"]
|
||||
MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp")
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeculativeConfig:
|
||||
"""Configuration for speculative decoding."""
|
||||
|
||||
enforce_eager: Optional[bool] = None
|
||||
"""Override the default enforce_eager from model_config"""
|
||||
# General speculative decoding control
|
||||
num_speculative_tokens: SkipValidation[int] = None # type: ignore
|
||||
"""The number of speculative tokens, if provided. It will default to the
|
||||
@ -143,7 +146,7 @@ class SpeculativeConfig:
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
if hf_config.model_type == "deepseek_v3":
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
@ -207,11 +210,21 @@ class SpeculativeConfig:
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.method in MTP_MODEL_TYPES:
|
||||
logger.warning("method `%s` is deprecated and replaced with mtp.",
|
||||
self.method)
|
||||
self.method = "mtp"
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||
if (self.target_model_config
|
||||
and self.target_model_config.hf_text_config.model_type
|
||||
in ("deepseek_v3", "mimo", "ernie4_5_moe", "qwen3_next")):
|
||||
if self.method == "mtp":
|
||||
assert (
|
||||
self.target_model_config
|
||||
is not None), "target_model_config must be present for mtp"
|
||||
if self.target_model_config.hf_text_config.model_type \
|
||||
== "deepseek_v32":
|
||||
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
|
||||
# remove this when the issue is fixed.
|
||||
self.enforce_eager = True
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
@ -281,6 +294,8 @@ class SpeculativeConfig:
|
||||
trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.
|
||||
allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.
|
||||
allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
@ -312,31 +327,13 @@ class SpeculativeConfig:
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
in MTP_MODEL_TYPES):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Deepseek MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"ernie_mtp"):
|
||||
self.method = "ernie_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Ernie MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"qwen3_next_mtp"):
|
||||
self.method = "qwen3_next_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
"Enabling num_speculative_tokens > 1 will run " \
"multiple forward passes over the same MTP layer, " \
"which may result in a lower acceptance rate." \
)
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("longcat_flash_mtp")):
|
||||
@ -353,7 +350,7 @@ class SpeculativeConfig:
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or deepseek_mtp.")
|
||||
"eagle, or mtp.")
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
@ -562,8 +559,7 @@ class SpeculativeConfig:
|
||||
return self.num_speculative_tokens
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
|
||||
"qwen3_next_mtp", "longcat_flash_mtp")
|
||||
return self.method in ("eagle", "eagle3", "mtp")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
|
@ -54,6 +54,7 @@ class HTTPConnection:
|
||||
stream: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
extra_headers: Optional[Mapping[str, str]] = None,
|
||||
allow_redirects: bool = True,
|
||||
):
|
||||
self._validate_http_url(url)
|
||||
|
||||
@ -63,7 +64,8 @@ class HTTPConnection:
|
||||
return client.get(url,
|
||||
headers=self._headers(**extra_headers),
|
||||
stream=stream,
|
||||
timeout=timeout)
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects)
|
||||
|
||||
async def get_async_response(
|
||||
self,
|
||||
@ -71,6 +73,7 @@ class HTTPConnection:
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
extra_headers: Optional[Mapping[str, str]] = None,
|
||||
allow_redirects: bool = True,
|
||||
):
|
||||
self._validate_http_url(url)
|
||||
|
||||
@ -79,10 +82,17 @@ class HTTPConnection:
|
||||
|
||||
return client.get(url,
|
||||
headers=self._headers(**extra_headers),
|
||||
timeout=timeout)
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects)
|
||||
|
||||
def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
|
||||
with self.get_response(url, timeout=timeout) as r:
|
||||
def get_bytes(self,
|
||||
url: str,
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
allow_redirects: bool = True) -> bytes:
|
||||
with self.get_response(url,
|
||||
timeout=timeout,
|
||||
allow_redirects=allow_redirects) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return r.content
|
||||
@ -92,8 +102,10 @@ class HTTPConnection:
|
||||
url: str,
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
allow_redirects: bool = True,
|
||||
) -> bytes:
|
||||
async with await self.get_async_response(url, timeout=timeout) as r:
|
||||
async with await self.get_async_response(
|
||||
url, timeout=timeout, allow_redirects=allow_redirects) as r:
|
||||
r.raise_for_status()
|
||||
|
||||
return await r.read()
|
||||
|
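The new allow_redirects flag threads through all of the fetch helpers above. A typical call site, assuming the module-level HTTPConnection singleton that vLLM already exposes, would be:

# Hedged sketch of the new keyword; the singleton name is assumed to be the
# existing module-level HTTPConnection instance in vllm.connections.
from vllm.connections import global_http_connection

data = global_http_connection.get_bytes(
    "https://example.com/asset.png",
    timeout=5,
    allow_redirects=False,  # refuse redirects when fetching untrusted media URLs
)
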
@ -6,7 +6,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import has_deep_ep, has_pplx
|
||||
@ -34,41 +34,60 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
cu_tokens_across_sp_cpu: torch.Tensor,
|
||||
is_sequence_parallel: bool) -> torch.Tensor:
|
||||
assert (len(x.shape) == 2)
|
||||
buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)),
|
||||
buffer = torch.empty((cu_tokens_across_sp_cpu[-1], x.size(1)),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
self.dp_rank - 1]
|
||||
end = cu_tokens_across_dp_cpu[self.dp_rank]
|
||||
rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
world_size = (self.world_size
|
||||
if is_sequence_parallel else self.dp_world_size)
|
||||
|
||||
start = 0 if rank == 0 else cu_tokens_across_sp_cpu[rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[rank]
|
||||
buffer[start:end, :].copy_(x)
|
||||
for idx in range(self.dp_world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1]
|
||||
end = cu_tokens_across_dp_cpu[idx]
|
||||
self.dp_group.broadcast(buffer[start:end, :], idx)
|
||||
for idx in range(world_size):
|
||||
start = 0 if idx == 0 else cu_tokens_across_sp_cpu[idx - 1]
|
||||
end = cu_tokens_across_sp_cpu[idx]
|
||||
get_ep_group().broadcast(buffer[start:end, :], idx)
|
||||
|
||||
return buffer
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_sp_cpu,
|
||||
is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
|
||||
ep_rank = self.rank if is_sequence_parallel else self.dp_rank
|
||||
|
||||
dp_metadata = get_forward_context().dp_metadata
|
||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
|
||||
|
||||
start = 0 if ep_rank == 0 else cu_tokens_across_sp_cpu[ep_rank - 1]
|
||||
end = cu_tokens_across_sp_cpu[ep_rank]
|
||||
|
||||
all_hidden_states = get_ep_group().all_reduce(hidden_states)
|
||||
hidden_states = all_hidden_states[start:end, :]
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@ -84,29 +103,40 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
||||
def __init__(self, cpu_group):
|
||||
super().__init__(cpu_group)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Gather hidden_states and router_logits from all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states, router_logits = get_dp_group().all_gatherv(
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||
hidden_states, router_logits = dist_group.all_gatherv(
|
||||
[hidden_states, router_logits],
|
||||
dim=0,
|
||||
sizes=sizes,
|
||||
)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Reduce-scatter hidden_states across all dp ranks.
|
||||
"""
|
||||
sizes = get_forward_context(
|
||||
).dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
|
||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||
hidden_states = dist_group.reduce_scatterv(hidden_states,
|
||||
dim=0,
|
||||
sizes=sizes)
|
||||
return hidden_states
|
||||
|
||||
def destroy(self):
|
||||
@ -148,11 +178,17 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
||||
kwargs, pplx.AllToAll.internode
|
||||
if self.internode else pplx.AllToAll.intranode)
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -184,11 +220,17 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
||||
def get_handle(self, kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
def dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -395,4 +437,4 @@ class FlashInferAllToAllManager(All2AllManagerBase):
|
||||
self.workspace_tensor = None
|
||||
self.prepare_workspace_tensor = None
|
||||
self.mapping = None
|
||||
self.initialized = False
|
||||
self.initialized = False
|
||||
|
@ -28,6 +28,8 @@ class Cache:
|
||||
|
||||
|
||||
class All2AllManagerBase:
|
||||
rank: int
|
||||
world_size: int
|
||||
|
||||
def __init__(self, cpu_group):
|
||||
self.cpu_group = cpu_group
|
||||
@ -40,6 +42,7 @@ class All2AllManagerBase:
|
||||
# all2all lives in ep group, which is merged from dp and tp group
|
||||
self.dp_group = get_dp_group()
|
||||
self.tp_group = get_tp_group()
|
||||
|
||||
# no self.ep_group since self.ep_group is still in construction
|
||||
# when we create this object
|
||||
self.dp_rank = self.dp_group.rank_in_group
|
||||
@ -60,17 +63,21 @@ class All2AllManagerBase:
|
||||
# and reuse it for the same config.
|
||||
raise NotImplementedError
|
||||
|
||||
def dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_num_sms(self, num_sms: int):
|
||||
pass
|
||||
|
||||
def max_sms_used(self) -> Optional[int]:
|
||||
return None # None means it could use the whole GPU
|
||||
|
||||
def dispatch(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
raise NotImplementedError
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False):
|
||||
raise NotImplementedError
|
||||
|
||||
def destroy(self):
|
||||
@ -267,15 +274,20 @@ class DeviceCommunicatorBase:
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Dispatch the hidden states and router logits to the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
"""
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Combine the hidden states and router logits from the appropriate device.
|
||||
This is a no-op in the base class.
|
||||
|
@ -39,10 +39,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
|
||||
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
|
||||
|
||||
# ep does not use pynccl
|
||||
use_pynccl = "ep" not in unique_name
|
||||
|
||||
self.use_pynccl = use_pynccl
|
||||
self.use_custom_allreduce = use_custom_allreduce
|
||||
self.use_torch_symm_mem = use_torch_symm_mem
|
||||
|
||||
@ -57,7 +53,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
SymmMemCommunicator)
|
||||
|
||||
self.pynccl_comm: Optional[PyNcclCommunicator] = None
|
||||
if use_pynccl and self.world_size > 1:
|
||||
if self.world_size > 1:
|
||||
self.pynccl_comm = PyNcclCommunicator(
|
||||
group=self.cpu_group,
|
||||
device=self.device,
|
||||
@ -308,14 +304,20 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
||||
return output_list
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
return hidden_states
|
||||
|
@ -75,14 +75,20 @@ class XpuCommunicator(DeviceCommunicatorBase):
|
||||
dist.broadcast(input_, src=src, group=self.device_group)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
||||
hidden_states, router_logits)
|
||||
hidden_states, router_logits, is_sequence_parallel)
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
assert self.all2all_manager is not None
|
||||
hidden_states = self.all2all_manager.combine(hidden_states)
|
||||
hidden_states = self.all2all_manager.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
return hidden_states
|
||||
|
@ -84,7 +84,7 @@ class NixlAgentMetadata(
|
||||
agent_metadata: bytes
|
||||
kv_caches_base_addr: list[int]
|
||||
num_blocks: int
|
||||
block_len: int
|
||||
block_lens: list[int]
|
||||
attn_backend_name: str
|
||||
kv_cache_layout: str
|
||||
|
||||
@ -105,6 +105,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
|
||||
self.reqs_to_send: dict[ReqId, float] = {}
|
||||
self.reqs_in_batch: set[ReqId] = set()
|
||||
|
||||
def add_new_req(
|
||||
self,
|
||||
@ -278,6 +279,7 @@ class NixlConnectorScheduler:
|
||||
self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
|
||||
# Reqs to send and their expiration time
|
||||
self._reqs_need_send: dict[ReqId, float] = {}
|
||||
self._reqs_in_batch: set[ReqId] = set()
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
@ -324,6 +326,9 @@ class NixlConnectorScheduler:
|
||||
|
||||
if not params:
|
||||
return
|
||||
|
||||
if params.get("do_remote_decode"):
|
||||
self._reqs_in_batch.add(request.request_id)
|
||||
if self.use_host_buffer and params.get("do_remote_decode"):
|
||||
# NOTE: when accelerator is not directly supported by Nixl,
|
||||
# prefilled blocks need to be saved to host memory before transfer.
|
||||
@ -373,6 +378,8 @@ class NixlConnectorScheduler:
|
||||
request_id=req_id,
|
||||
local_block_ids=block_ids,
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
load_remote_cache=True,
|
||||
save_to_host=False,
|
||||
)
|
||||
|
||||
for req_id, (req, block_ids) in self._reqs_need_save.items():
|
||||
@ -386,10 +393,12 @@ class NixlConnectorScheduler:
|
||||
)
|
||||
|
||||
meta.reqs_to_send = self._reqs_need_send
|
||||
meta.reqs_in_batch = self._reqs_in_batch
|
||||
|
||||
# Clear the list once workers start the transfers
|
||||
self._reqs_need_recv.clear()
|
||||
self._reqs_need_save.clear()
|
||||
self._reqs_in_batch = set()
|
||||
self._reqs_need_send = {}
|
||||
|
||||
return meta
|
||||
@ -465,8 +474,11 @@ class NixlConnectorWorker:
|
||||
"backends", ["UCX"])
|
||||
# Agent.
|
||||
non_ucx_backends = [b for b in self.nixl_backends if b != "UCX"]
|
||||
config = nixl_agent_config(backends=self.nixl_backends) if len(
|
||||
non_ucx_backends) > 0 and nixl_agent_config is not None else None
|
||||
if nixl_agent_config is None:
|
||||
config = None
|
||||
else:
|
||||
config = nixl_agent_config(backends=self.nixl_backends) if len(
|
||||
non_ucx_backends) > 0 else nixl_agent_config(num_threads=8)
|
||||
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
@ -546,6 +558,8 @@ class NixlConnectorWorker:
|
||||
self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
|
||||
# Track the expiration time of requests that are waiting to be sent.
|
||||
self._reqs_to_send: dict[ReqId, float] = {}
|
||||
# Set of requests that have been part of a batch, regardless of status.
|
||||
self._reqs_to_process: set[ReqId] = set()
|
||||
|
||||
# Background thread for handling new handshake requests.
|
||||
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
|
||||
@ -752,6 +766,9 @@ class NixlConnectorWorker:
|
||||
split_k_and_v = not (self.use_mla or self._use_pallas
|
||||
or self._use_flashinfer)
|
||||
tensor_size_bytes = None
|
||||
# Enable different block lengths for different layers when MLA is used.
|
||||
self.block_len_per_layer = list[int]()
|
||||
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
|
||||
for layer_name, cache_or_caches in xfer_buffers.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [
|
||||
cache_or_caches
|
||||
@ -769,10 +786,25 @@ class NixlConnectorWorker:
|
||||
tensor_size_bytes = curr_tensor_size_bytes
|
||||
self.num_blocks = cache.shape[0]
|
||||
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, \
|
||||
"All kv cache tensors must have the same size"
|
||||
assert cache.shape[0] == self.num_blocks, \
|
||||
"All kv cache tensors must have the same number of blocks"
|
||||
|
||||
self.block_len_per_layer.append(curr_tensor_size_bytes //
|
||||
self.num_blocks)
|
||||
self.slot_size_per_layer.append(self.block_len_per_layer[-1] //
|
||||
self.block_size)
|
||||
|
||||
if not self.use_mla:
|
||||
# Different kv cache shape is not supported by HeteroTP
|
||||
assert tensor_size_bytes == curr_tensor_size_bytes, \
|
||||
"All kv cache tensors must have the same size"
|
||||
caches_data.append(
|
||||
(base_addr, tensor_size_bytes, self.tp_rank, ""))
|
||||
(base_addr, curr_tensor_size_bytes, self.tp_rank, ""))
|
||||
|
||||
logger.debug("Different block lengths collected: %s",
|
||||
set(self.block_len_per_layer))
|
||||
assert len(self.block_len_per_layer) == len(seen_base_addresses)
|
||||
assert self.num_blocks != 0
|
||||
|
||||
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
|
||||
self.num_regions = len(caches_data)
|
||||
@ -785,16 +817,12 @@ class NixlConnectorWorker:
|
||||
logger.debug("Done registering descs")
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
assert tensor_size_bytes is not None
|
||||
assert self.num_blocks != 0
|
||||
assert tensor_size_bytes % self.num_blocks == 0
|
||||
self.block_len = tensor_size_bytes // self.num_blocks
|
||||
self.slot_size_bytes = self.block_len // self.block_size
|
||||
self.device_kv_caches = kv_caches
|
||||
self.dst_num_blocks[self.engine_id] = self.num_blocks
|
||||
if self._use_flashinfer:
|
||||
assert self.slot_size_bytes % 2 == 0
|
||||
self.slot_size_bytes /= 2
|
||||
for i in range(len(self.slot_size_per_layer)):
|
||||
assert self.slot_size_per_layer[i] % 2 == 0
|
||||
self.slot_size_per_layer[i] //= 2
|
||||
|
||||
# NOTE (NickLucche) When FlashInfer is used, memory is registered
|
||||
# with joint KV for each block. This minimizes the overhead in
|
||||
@ -804,17 +832,17 @@ class NixlConnectorWorker:
|
||||
# of 'virtual' regions here and halve `block_len` below.
|
||||
self.num_regions *= 2
|
||||
|
||||
kv_block_len = self.get_backend_aware_kv_block_len()
|
||||
# Register local/src descr for NIXL xfer.
|
||||
blocks_data = []
|
||||
for base_addr in seen_base_addresses:
|
||||
for i, base_addr in enumerate(seen_base_addresses):
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||
# NOTE With heter-TP, more blocks are prepared than what are
|
||||
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
||||
# could create fewer, but then _get_block_descs_ids needs to
|
||||
# select agent_meta.num_blocks instead of self.num_blocks for
|
||||
# local descr, and that makes handling regular flow less clean.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
block_offset = block_id * self.block_len_per_layer[i]
|
||||
addr = base_addr + block_offset
|
||||
# (addr, len, device id)
|
||||
blocks_data.append((addr, kv_block_len, self.tp_rank))
|
||||
@ -824,7 +852,7 @@ class NixlConnectorWorker:
|
||||
# descs ordering. This is needed for selecting contiguous heads
|
||||
# when split across TP ranks.
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
block_offset = block_id * self.block_len_per_layer[i]
|
||||
addr = base_addr + block_offset
|
||||
# Register addresses for V cache (K registered first).
|
||||
v_addr = addr + kv_block_len
|
||||
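For reference, a sketch of the local descriptor registration pattern the loop above implements: one (addr, len, device) entry per block and per layer, with the V halves registered after all K halves when K and V share a single block. Addresses and sizes below are illustrative only:

```python
# Sketch only: build per-block NIXL-style descriptors (addr, len, device_id)
# for each layer, registering V right after K when KV are stored jointly.
def build_local_descs(base_addrs, block_len_per_layer, num_blocks,
                      tp_rank, use_flashinfer):
    blocks_data = []
    for i, base_addr in enumerate(base_addrs):
        # Half the block indexes just K (or just V) when KV are stored jointly.
        kv_block_len = block_len_per_layer[i] // 2 if use_flashinfer \
            else block_len_per_layer[i]
        for block_id in range(num_blocks):
            addr = base_addr + block_id * block_len_per_layer[i]
            blocks_data.append((addr, kv_block_len, tp_rank))
        if use_flashinfer:
            # Register V addresses after all K addresses to keep a stable
            # ordering when heads are split across TP ranks.
            for block_id in range(num_blocks):
                addr = base_addr + block_id * block_len_per_layer[i]
                blocks_data.append((addr + kv_block_len, kv_block_len, tp_rank))
    return blocks_data

print(build_local_descs([0x1000], [4096], num_blocks=2, tp_rank=0,
                        use_flashinfer=True))
```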
@ -864,7 +892,7 @@ class NixlConnectorWorker:
|
||||
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
||||
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
||||
num_blocks=self.num_blocks,
|
||||
block_len=self.block_len,
|
||||
block_lens=self.block_len_per_layer,
|
||||
attn_backend_name=self.backend_name,
|
||||
kv_cache_layout=self.kv_cache_layout)
|
||||
ready_event = threading.Event()
|
||||
@ -889,7 +917,7 @@ class NixlConnectorWorker:
|
||||
The latter, assuming D.world_size > P.world_size, requires that two or
|
||||
more local TP workers share the xfer from a single TP worker.
|
||||
|
||||
Here's an example:
|
||||
Here's an example (non-MLA case):
|
||||
|
||||
rank_offset p_remote_tp_rank
|
||||
(kv split no)
|
||||
@ -945,14 +973,20 @@ class NixlConnectorWorker:
|
||||
total_num_kv_heads = self.model_config.get_total_num_kv_heads()
|
||||
is_kv_replicated = self._tp_size[engine_id] // total_num_kv_heads >= 1
|
||||
|
||||
remote_block_len = nixl_agent_meta.block_lens[0]
|
||||
if self.use_mla or is_kv_replicated:
|
||||
# With MLA the only difference is in the number of blocks.
|
||||
remote_block_size = nixl_agent_meta.block_len // (
|
||||
self.slot_size_bytes)
|
||||
assert self.block_len == nixl_agent_meta.block_len
|
||||
# With replicated KV cache, only the number of blocks can differ.
|
||||
assert self.block_len_per_layer == nixl_agent_meta.block_lens, \
|
||||
"KV cache sizes must match between P and D when replicated"
|
||||
remote_block_size = remote_block_len // (
|
||||
self.slot_size_per_layer[0])
|
||||
else:
|
||||
remote_block_size = nixl_agent_meta.block_len // (
|
||||
self.slot_size_bytes * tp_ratio)
|
||||
# When MLA is not used, this is a list of the same block length
|
||||
for block_len in nixl_agent_meta.block_lens:
|
||||
assert block_len == remote_block_len, \
|
||||
"All remote layers must have the same block size"
|
||||
remote_block_size = remote_block_len // (
|
||||
self.slot_size_per_layer[0] * tp_ratio)
|
||||
if self._use_flashinfer:
|
||||
# With flashinfer, KV are sent in the same message.
|
||||
remote_block_size //= 2
|
||||
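A worked example of the heterogeneous-TP check above, with assumed sizes: the remote P worker holds tp_ratio times more KV heads per block, so its block is tp_ratio times longer while the token block_size must come out identical:

```python
# Worked example (illustrative numbers): a P worker with TP=1 holds all 8 KV
# heads, a D worker with TP=2 holds 4, so tp_ratio = 2 and each remote block
# is twice as long as a local one while the token block_size matches.
bytes_per_elem = 2          # fp16 (assumed)
head_dim, block_size = 64, 16
local_kv_heads, tp_ratio = 4, 2

local_slot_size = local_kv_heads * head_dim * bytes_per_elem        # 512
local_block_len = local_slot_size * block_size                      # 8192
remote_block_len = local_block_len * tp_ratio                       # 16384

# Same check as above: remote tokens-per-block must equal the local one.
remote_block_size = remote_block_len // (local_slot_size * tp_ratio)
assert remote_block_size == block_size
```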
@ -963,14 +997,14 @@ class NixlConnectorWorker:
|
||||
raise ValueError(
|
||||
"Heterogeneous TP is not supported on XPU")
|
||||
|
||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
|
||||
assert remote_block_len == self.block_len_per_layer[0] * tp_ratio, (
|
||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
||||
)
|
||||
|
||||
assert self.block_size == remote_block_size, (
|
||||
"Remote P worker with different block size is not supported "
|
||||
f"{self.block_size=} {remote_block_size=}")
|
||||
"Remote P worker with different page/block size is not supported "
|
||||
f"{self.block_size=}, {remote_block_size=}")
|
||||
|
||||
# Create dst descs and xfer side handles. TP workers have same #blocks.
|
||||
if engine_id in self.dst_num_blocks:
|
||||
@ -985,13 +1019,16 @@ class NixlConnectorWorker:
|
||||
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
||||
self.kv_caches_base_addr[
|
||||
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
||||
kv_block_len = self.get_backend_aware_kv_block_len()
|
||||
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
|
||||
if not (self.use_mla or is_kv_replicated) else 0
|
||||
|
||||
assert len(nixl_agent_meta.kv_caches_base_addr) == len(
|
||||
self.block_len_per_layer)
|
||||
# Register all remote blocks, but only the corresponding kv heads.
|
||||
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
||||
for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
|
||||
kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
|
||||
rank_offset = self.tp_rank % tp_ratio * kv_block_len \
|
||||
if not (self.use_mla or is_kv_replicated) else 0
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_len
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
# For each block, grab the heads chunk belonging to rank_i
|
||||
# of size remote_nheads // tp_ratio, which correspond to
|
||||
# self.block_len == remote_block_len//tp_ratio bytes.
|
||||
@ -1002,9 +1039,9 @@ class NixlConnectorWorker:
|
||||
if self._use_flashinfer:
|
||||
# With FlashInfer index V separately to allow head splitting.
|
||||
for block_id in range(nixl_agent_meta.num_blocks):
|
||||
block_offset = block_id * nixl_agent_meta.block_len
|
||||
block_offset = block_id * nixl_agent_meta.block_lens[i]
|
||||
addr = base_addr + block_offset + rank_offset
|
||||
v_addr = addr + nixl_agent_meta.block_len // 2
|
||||
v_addr = addr + nixl_agent_meta.block_lens[i] // 2
|
||||
blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
|
||||
|
||||
logger.debug(
|
||||
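A sketch of how a D rank addresses its contiguous heads chunk inside each remote block via rank_offset (illustrative numbers only):

```python
# Sketch: each of the tp_ratio D-ranks reads a contiguous kv_block_len slice
# out of every remote P block, offset by its position within the group.
def remote_chunk_addrs(base_addr, remote_block_len, kv_block_len,
                       num_blocks, tp_rank, tp_ratio):
    rank_offset = (tp_rank % tp_ratio) * kv_block_len
    return [base_addr + block_id * remote_block_len + rank_offset
            for block_id in range(num_blocks)]

# With remote_block_len=16384 and kv_block_len=8192, D-rank 0 reads the first
# half of each remote block and D-rank 1 the second half.
print(remote_chunk_addrs(0x0, 16384, 8192, num_blocks=2, tp_rank=1, tp_ratio=2))
# [8192, 24576]
```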
@ -1082,6 +1119,7 @@ class NixlConnectorWorker:
|
||||
"Releasing expired KV blocks for request %s which were "
|
||||
"retrieved by %d decode worker(s) within %d seconds.", req_id,
|
||||
count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
|
||||
self._reqs_to_process.remove(req_id)
|
||||
del self._reqs_to_send[req_id]
|
||||
done_sending.add(req_id)
|
||||
|
||||
@ -1097,7 +1135,8 @@ class NixlConnectorWorker:
|
||||
for notifs in self.nixl_wrapper.get_new_notifs().values():
|
||||
for notif in notifs:
|
||||
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
|
||||
if req_id not in self._reqs_to_send:
|
||||
if (req_id not in self._reqs_to_send
|
||||
and req_id not in self._reqs_to_process):
|
||||
logger.error(
|
||||
"Potentially invalid KV blocks for "
|
||||
"unrecognized request %s were retrieved by "
|
||||
@ -1110,7 +1149,8 @@ class NixlConnectorWorker:
|
||||
tp_ratio):
|
||||
notified_req_ids.add(req_id)
|
||||
del self.consumer_notification_counts_by_req[req_id]
|
||||
del self._reqs_to_send[req_id]
|
||||
self._reqs_to_process.remove(req_id)
|
||||
self._reqs_to_send.pop(req_id, None)
|
||||
return notified_req_ids
|
||||
|
||||
def _pop_done_transfers(
|
||||
@ -1171,8 +1211,19 @@ class NixlConnectorWorker:
|
||||
while not self._ready_requests.empty():
|
||||
self._read_blocks_for_req(*self._ready_requests.get_nowait())
|
||||
|
||||
# Keep around the requests that have been part of a batch. This is
|
||||
# needed because async scheduling widens the gap between the moment at
# which a request's expiration is set (P side) and the moment at which
# blocks are read from D. As P can now more easily lag behind D
|
||||
# while processing the next batch, we make sure to only set an
|
||||
# expiration for requests that have not been read from D yet.
|
||||
for req_id in metadata.reqs_in_batch:
|
||||
self._reqs_to_process.add(req_id)
|
||||
|
||||
# Add to requests that are waiting to be read and track expiration.
|
||||
self._reqs_to_send.update(metadata.reqs_to_send)
|
||||
for req_id, expiration_time in metadata.reqs_to_send.items():
|
||||
if req_id in self._reqs_to_process:
|
||||
self._reqs_to_send[req_id] = expiration_time
|
||||
|
||||
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
|
||||
logger.debug(
|
||||
@ -1317,7 +1368,7 @@ class NixlConnectorWorker:
|
||||
descs_ids = region_ids * num_blocks + block_ids
|
||||
return descs_ids.flatten()
|
||||
|
||||
def get_backend_aware_kv_block_len(self):
|
||||
def get_backend_aware_kv_block_len(self, layer_idx: int):
|
||||
"""
|
||||
Get the block length for one K/V element (K and V have the same size).
|
||||
|
||||
@ -1328,9 +1379,9 @@ class NixlConnectorWorker:
|
||||
"""
|
||||
if self._use_flashinfer:
|
||||
# For indexing only half (either just the K or V part).
|
||||
block_len = self.block_len // 2
|
||||
block_len = self.block_len_per_layer[layer_idx] // 2
|
||||
else:
|
||||
block_len = self.block_len
|
||||
block_len = self.block_len_per_layer[layer_idx]
|
||||
return block_len
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
@ -871,17 +871,24 @@ class GroupCoordinator:
|
||||
model)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_sequence_parallel: bool = False
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.dispatch(hidden_states,
|
||||
router_logits)
|
||||
router_logits,
|
||||
is_sequence_parallel)
|
||||
else:
|
||||
return hidden_states, router_logits
|
||||
|
||||
def combine(self, hidden_states) -> torch.Tensor:
|
||||
def combine(self,
|
||||
hidden_states,
|
||||
is_sequence_parallel: bool = False) -> torch.Tensor:
|
||||
if self.device_communicator is not None:
|
||||
return self.device_communicator.combine(hidden_states)
|
||||
return self.device_communicator.combine(hidden_states,
|
||||
is_sequence_parallel)
|
||||
else:
|
||||
return hidden_states
@ -297,6 +297,8 @@ class EngineArgs:
|
||||
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
|
||||
trust_remote_code: bool = ModelConfig.trust_remote_code
|
||||
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
|
||||
allowed_media_domains: Optional[
|
||||
list[str]] = ModelConfig.allowed_media_domains
|
||||
download_dir: Optional[str] = LoadConfig.download_dir
|
||||
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
|
||||
load_format: Union[str, LoadFormats] = LoadConfig.load_format
|
||||
@ -531,6 +533,8 @@ class EngineArgs:
|
||||
**model_kwargs["hf_config_path"])
|
||||
model_group.add_argument("--allowed-local-media-path",
|
||||
**model_kwargs["allowed_local_media_path"])
|
||||
model_group.add_argument("--allowed-media-domains",
|
||||
**model_kwargs["allowed_media_domains"])
|
||||
model_group.add_argument("--revision", **model_kwargs["revision"])
|
||||
model_group.add_argument("--code-revision",
|
||||
**model_kwargs["code_revision"])
|
||||
@ -997,6 +1001,7 @@ class EngineArgs:
|
||||
tokenizer_mode=self.tokenizer_mode,
|
||||
trust_remote_code=self.trust_remote_code,
|
||||
allowed_local_media_path=self.allowed_local_media_path,
|
||||
allowed_media_domains=self.allowed_media_domains,
|
||||
dtype=self.dtype,
|
||||
seed=self.seed,
|
||||
revision=self.revision,
|
||||
@ -1481,7 +1486,7 @@ class EngineArgs:
|
||||
raise NotImplementedError(
|
||||
"Draft model speculative decoding is not supported yet. "
|
||||
"Please consider using other speculative decoding methods "
|
||||
"such as ngram, medusa, eagle, or deepseek_mtp.")
|
||||
"such as ngram, medusa, eagle, or mtp.")
|
||||
|
||||
V1_BACKENDS = [
|
||||
"FLASH_ATTN",
@ -11,7 +11,12 @@ from pathlib import Path
|
||||
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
|
||||
cast)
|
||||
|
||||
import jinja2
|
||||
import jinja2.ext
|
||||
import jinja2.meta
|
||||
import jinja2.nodes
|
||||
import jinja2.parser
|
||||
import jinja2.sandbox
|
||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import random_uuid, supports_kw
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -632,6 +637,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
||||
def allowed_local_media_path(self):
|
||||
return self._model_config.allowed_local_media_path
|
||||
|
||||
@property
|
||||
def allowed_media_domains(self):
|
||||
return self._model_config.allowed_media_domains
|
||||
|
||||
@property
|
||||
def mm_registry(self):
|
||||
return MULTIMODAL_REGISTRY
|
||||
@ -832,6 +841,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@ -916,6 +926,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
||||
self._connector = MediaConnector(
|
||||
media_io_kwargs=media_io_kwargs,
|
||||
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||
allowed_media_domains=tracker.allowed_media_domains,
|
||||
)
|
||||
|
||||
def parse_image(
|
||||
@ -1548,6 +1559,46 @@ def parse_chat_messages_futures(
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||
# only preserve the parse function used to resolve chat template kwargs
|
||||
class AssistantTracker(jinja2.ext.Extension):
|
||||
tags = {"generation"}
|
||||
|
||||
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
||||
lineno = next(parser.stream).lineno
|
||||
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
||||
call = self.call_method("_generation_support")
|
||||
call_block = jinja2.nodes.CallBlock(call, [], [], body)
|
||||
return call_block.set_lineno(lineno)
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
chat_template: str,
|
||||
chat_template_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
fn_kw = {
|
||||
k for k in chat_template_kwargs
|
||||
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
|
||||
}
|
||||
|
||||
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
|
||||
)
|
||||
parsed_content = env.parse(chat_template)
|
||||
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
|
||||
|
||||
# We exclude chat_template from kwargs here, because
|
||||
# chat template has been already resolved at this stage
|
||||
unexpected_vars = {"chat_template"}
|
||||
accept_vars = (fn_kw | template_vars) - unexpected_vars
|
||||
return {
|
||||
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
|
||||
}
|
||||
|
||||
|
||||
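resolve_chat_template_kwargs keeps only the kwargs that are either named parameters of apply_chat_template or undeclared variables of the template itself. A standalone sketch of the same filtering, using inspect.signature as a stand-in for vllm.utils.supports_kw (an assumption, not the real helper):

```python
# Standalone sketch of the kwargs filtering idea, with inspect.signature
# standing in for vllm.utils.supports_kw.
import inspect
import jinja2.meta
import jinja2.sandbox

def fake_apply_chat_template(conversation, tools=None, documents=None, **kwargs):
    ...

template = "{% for m in messages %}{{ m['content'] }}{% endfor %}{{ extra_note }}"
requested = {"documents": [], "extra_note": "hi", "unknown_kwarg": 1,
             "chat_template": "ignored"}

fn_kw = {
    k for k in requested
    if k in inspect.signature(fake_apply_chat_template).parameters
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment()
template_vars = jinja2.meta.find_undeclared_variables(env.parse(template))

# chat_template is excluded because it was already resolved earlier.
accept = (fn_kw | template_vars) - {"chat_template"}
resolved = {k: v for k, v in requested.items() if k in accept}
print(resolved)   # {'documents': [], 'extra_note': 'hi'}
```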
def apply_hf_chat_template(
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
conversation: list[ConversationMessage],
|
||||
@ -1573,12 +1624,17 @@ def apply_hf_chat_template(
|
||||
)
|
||||
|
||||
try:
|
||||
resolved_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=hf_chat_template,
|
||||
chat_template_kwargs=kwargs,
|
||||
)
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
chat_template=hf_chat_template,
|
||||
tokenize=tokenize,
|
||||
**kwargs,
|
||||
**resolved_kwargs,
|
||||
)
|
||||
|
||||
# External library exceptions can sometimes occur despite the framework's
@ -86,6 +86,8 @@ class LLM:
|
||||
or videos from directories specified by the server file system.
|
||||
This is a security risk. Should only be enabled in trusted
|
||||
environments.
|
||||
allowed_media_domains: If set, only media URLs that belong to one of
these domains can be used for multi-modal inputs.
|
||||
tensor_parallel_size: The number of GPUs to use for distributed
|
||||
execution with tensor parallelism.
|
||||
dtype: The data type for the model weights and activations. Currently,
|
||||
@ -169,6 +171,7 @@ class LLM:
|
||||
skip_tokenizer_init: bool = False,
|
||||
trust_remote_code: bool = False,
|
||||
allowed_local_media_path: str = "",
|
||||
allowed_media_domains: Optional[list[str]] = None,
|
||||
tensor_parallel_size: int = 1,
|
||||
dtype: ModelDType = "auto",
|
||||
quantization: Optional[QuantizationMethods] = None,
|
||||
@ -264,6 +267,7 @@ class LLM:
|
||||
skip_tokenizer_init=skip_tokenizer_init,
|
||||
trust_remote_code=trust_remote_code,
|
||||
allowed_local_media_path=allowed_local_media_path,
|
||||
allowed_media_domains=allowed_media_domains,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
dtype=dtype,
|
||||
quantization=quantization,
@ -3,12 +3,14 @@
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import multiprocessing
|
||||
import multiprocessing.forkserver as forkserver
|
||||
import os
|
||||
import secrets
|
||||
import signal
|
||||
import socket
|
||||
import tempfile
|
||||
@ -1252,7 +1254,7 @@ def load_log_config(log_config_file: Optional[str]) -> Optional[dict]:
|
||||
class AuthenticationMiddleware:
|
||||
"""
|
||||
Pure ASGI middleware that authenticates each request by checking
|
||||
if the Authorization header exists and equals "Bearer {api_key}".
|
||||
if the Authorization Bearer token exists and matches any of the configured API keys.
|
||||
|
||||
Notes
|
||||
-----
|
||||
@ -1263,7 +1265,26 @@ class AuthenticationMiddleware:
|
||||
|
||||
def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
|
||||
self.app = app
|
||||
self.api_tokens = {f"Bearer {token}" for token in tokens}
|
||||
self.api_tokens = [
|
||||
hashlib.sha256(t.encode("utf-8")).digest() for t in tokens
|
||||
]
|
||||
|
||||
def verify_token(self, headers: Headers) -> bool:
|
||||
authorization_header_value = headers.get("Authorization")
|
||||
if not authorization_header_value:
|
||||
return False
|
||||
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
if scheme.lower() != "bearer":
|
||||
return False
|
||||
|
||||
param_hash = hashlib.sha256(param.encode("utf-8")).digest()
|
||||
|
||||
token_match = False
|
||||
for token_hash in self.api_tokens:
|
||||
token_match |= secrets.compare_digest(param_hash, token_hash)
|
||||
|
||||
return token_match
|
||||
|
||||
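The middleware now compares SHA-256 digests of the presented bearer token against every configured key with secrets.compare_digest, so the check is constant-time and tolerant of differing token lengths. A minimal standalone sketch of the same pattern:

```python
# Minimal standalone sketch of the hashed, constant-time token check.
import hashlib
import secrets

def make_token_hashes(tokens: list[str]) -> list[bytes]:
    return [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

def check_bearer(header_value: str, token_hashes: list[bytes]) -> bool:
    scheme, _, param = header_value.partition(" ")
    if scheme.lower() != "bearer":
        return False
    param_hash = hashlib.sha256(param.encode("utf-8")).digest()
    # Compare against every stored hash so the work done does not depend on
    # which (if any) token matches.
    match = False
    for token_hash in token_hashes:
        match |= secrets.compare_digest(param_hash, token_hash)
    return match

hashes = make_token_hashes(["key-a", "key-b"])
assert check_bearer("Bearer key-b", hashes)
assert not check_bearer("Bearer wrong", hashes)
assert not check_bearer("Basic key-a", hashes)
```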
def __call__(self, scope: Scope, receive: Receive,
|
||||
send: Send) -> Awaitable[None]:
|
||||
@ -1276,8 +1297,7 @@ class AuthenticationMiddleware:
|
||||
url_path = URL(scope=scope).path.removeprefix(root_path)
|
||||
headers = Headers(scope=scope)
|
||||
# Type narrow to satisfy mypy.
|
||||
if url_path.startswith("/v1") and headers.get(
|
||||
"Authorization") not in self.api_tokens:
|
||||
if url_path.startswith("/v1") and not self.verify_token(headers):
|
||||
response = JSONResponse(content={"error": "Unauthorized"},
|
||||
status_code=401)
|
||||
return response(scope, receive, send)
|
||||
@ -1696,6 +1716,7 @@ async def init_app_state(
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
exclude_tools_when_tool_choice_none=args.
@ -103,9 +103,13 @@ class FrontendArgs:
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
|
||||
"""The format to render message content within a chat template.
|
||||
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to OpenAI
|
||||
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
* "string" will render the content as a string. Example: `"Hello World"`
|
||||
* "openai" will render the content as a list of dictionaries, similar to
|
||||
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
|
||||
trust_request_chat_template: bool = False
|
||||
"""Whether to trust the chat template provided in the request. If False,
|
||||
the server will always use the chat template specified by `--chat-template`
|
||||
or the ones from tokenizer."""
|
||||
response_role: str = "assistant"
|
||||
"""The role name to return if `request.add_generation_prompt=true`."""
|
||||
ssl_keyfile: Optional[str] = None
@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_logger: Optional[RequestLogger],
|
||||
chat_template: Optional[str],
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
trust_request_chat_template: bool = False,
|
||||
return_tokens_as_token_ids: bool = False,
|
||||
reasoning_parser: str = "",
|
||||
enable_auto_tools: bool = False,
|
||||
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.response_role = response_role
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
self.enable_log_outputs = enable_log_outputs
|
||||
|
||||
# set up tool use
|
||||
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if not self.use_harmony:
|
||||
# Common case.
|
||||
request_chat_template = request.chat_template
|
||||
chat_template_kwargs = request.chat_template_kwargs
|
||||
if not self.trust_request_chat_template and (
|
||||
request_chat_template is not None or
|
||||
(chat_template_kwargs and
|
||||
chat_template_kwargs.get("chat_template") is not None)):
|
||||
return self.create_error_response(
|
||||
"Chat template is passed with request, but "
|
||||
"--trust-request-chat-template is not set. "
|
||||
"Refused request with untrusted chat template.")
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template=request_chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
@ -68,6 +68,7 @@ if TYPE_CHECKING:
|
||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
|
||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||
@ -725,6 +726,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||
|
||||
# Whether to allow HTTP redirects when fetching from media URLs.
|
||||
# Default to True
|
||||
"VLLM_MEDIA_URL_ALLOW_REDIRECTS":
|
||||
lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))),
|
||||
|
||||
# Max number of workers for the thread pool handling
|
||||
# media bytes loading. Set to 1 to disable parallel processing.
|
||||
# Default is 8
@ -49,16 +49,29 @@ class BatchDescriptor(NamedTuple):
|
||||
return BatchDescriptor(self.num_tokens, uniform_decode=False)
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int) -> list[int]:
|
||||
sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
|
||||
sequence_parallel_size)
|
||||
|
||||
sp_tokens = sp_tokens.repeat_interleave(sequence_parallel_size)
|
||||
return sp_tokens.tolist()
|
||||
|
||||
|
||||
def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
|
||||
sequence_parallel_size: int,
|
||||
max_num_tokens: int,
|
||||
chunk_idx: int) -> list[int]:
|
||||
dp_size = len(num_tokens_across_dp_cpu)
|
||||
|
||||
local_size = [-1] * dp_size
|
||||
for i in range(dp_size):
|
||||
dp_tokens = num_tokens_across_dp_cpu[i]
|
||||
sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
|
||||
sequence_parallel_size)
|
||||
sp_size = len(sp_tokens)
|
||||
|
||||
local_size = [-1] * sp_size
|
||||
for i in range(sp_size):
|
||||
# Take into account sharding if MoE activation is sequence parallel.
|
||||
local_size[i] = min(max_num_tokens,
|
||||
dp_tokens - (max_num_tokens * chunk_idx))
|
||||
sp_tokens[i] - (max_num_tokens * chunk_idx))
|
||||
if local_size[i] <= 0:
|
||||
local_size[i] = 1 # ensure lockstep even if done
|
||||
return local_size
|
||||
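A worked example of the two helpers above, with assumed per-rank token counts:

```python
# Worked example of _compute_sp_num_tokens / _compute_chunked_local_num_tokens
# (illustrative numbers).
import torch

num_tokens_across_dp = torch.tensor([5, 8])   # tokens on each of 2 DP ranks
sp_size = 2                                   # SP degree inside each DP rank

# _compute_sp_num_tokens: ceil-divide per DP rank, then repeat per SP rank.
sp_tokens = (num_tokens_across_dp + sp_size - 1) // sp_size
sp_tokens = sp_tokens.repeat_interleave(sp_size)
print(sp_tokens.tolist())        # [3, 3, 4, 4]

# Chunked local sizes with at most 3 tokens per rank per chunk; ranks that are
# already done still report 1 token to stay in lockstep.
max_num_tokens = 3
for chunk_idx in range(2):
    local = [max(min(max_num_tokens, int(t) - max_num_tokens * chunk_idx), 1)
             for t in sp_tokens]
    print(chunk_idx, local)      # 0 -> [3, 3, 3, 3]; 1 -> [1, 1, 1, 1]
```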
@ -67,7 +80,9 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int],
|
||||
@dataclass
|
||||
class DPMetadata:
|
||||
max_tokens_across_dp_cpu: torch.Tensor
|
||||
cu_tokens_across_dp_cpu: torch.Tensor
|
||||
num_tokens_across_dp_cpu: torch.Tensor
|
||||
|
||||
# NOTE: local_sizes should only be set by the chunked_sizes context manager
|
||||
local_sizes: Optional[list[int]] = None
|
||||
|
||||
@staticmethod
|
||||
@ -98,6 +113,17 @@ class DPMetadata:
|
||||
dist.all_reduce(num_tokens_tensor, group=group)
|
||||
return num_tokens_tensor.cpu()
|
||||
|
||||
# Get the cumulative tokens across sequence parallel ranks.
|
||||
# In this case the input to the MoEs will be distributed w.r.t both
|
||||
# DP and TP rank.
|
||||
# When sp_size==1, this is just the cumulative num tokens across DP.
|
||||
def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
|
||||
num_tokens_across_sp_cpu = (
|
||||
(self.num_tokens_across_dp_cpu - 1 + sp_size) // sp_size)
|
||||
num_tokens_across_sp_cpu = (
|
||||
num_tokens_across_sp_cpu.repeat_interleave(sp_size))
|
||||
return torch.cumsum(num_tokens_across_sp_cpu, dim=0)
|
||||
|
||||
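Concrete values for cu_tokens_across_sp (assumed token counts):

```python
# Concrete values for cu_tokens_across_sp (sketch with assumed token counts).
import torch

num_tokens_across_dp_cpu = torch.tensor([5, 8])
sp_size = 2
per_sp = ((num_tokens_across_dp_cpu - 1 + sp_size) // sp_size
          ).repeat_interleave(sp_size)            # [3, 3, 4, 4]
print(torch.cumsum(per_sp, dim=0).tolist())       # [3, 6, 10, 14]
# With sp_size == 1 this reduces to the plain cumulative DP token counts.
```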
@staticmethod
|
||||
def should_ubatch_across_dp(
|
||||
should_ubatch: bool, orig_num_tokens_per_ubatch: int,
|
||||
@ -147,10 +173,10 @@ class DPMetadata:
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
parallel_config: ParallelConfig,
|
||||
attn_metadata: Any,
|
||||
num_tokens: int,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None
|
||||
parallel_config: ParallelConfig,
|
||||
attn_metadata: Any,
|
||||
num_tokens: int,
|
||||
num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
|
||||
) -> "DPMetadata":
|
||||
|
||||
assert parallel_config.data_parallel_size > 1
|
||||
@ -167,18 +193,18 @@ class DPMetadata:
|
||||
|
||||
# If num_tokens_across_dp is None, it will be computed by all_reduce
|
||||
# Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
|
||||
assert (num_tokens_across_dp is None or num_tokens_across_dp[dp_rank]
|
||||
== batchsize), f"{num_tokens_across_dp[dp_rank]} {batchsize}"
|
||||
if num_tokens_across_dp is None:
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
assert (num_tokens_across_dp_cpu is None
|
||||
or num_tokens_across_dp_cpu[dp_rank] == batchsize
|
||||
), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
|
||||
if num_tokens_across_dp_cpu is None:
|
||||
num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
|
||||
batchsize, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
|
||||
cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu,
|
||||
num_tokens_across_dp)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
|
||||
return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)
|
||||
|
||||
@contextmanager
|
||||
def chunked_sizes(self, max_chunk_size_per_rank: int, chunk_idx: int):
|
||||
def chunked_sizes(self, sequence_parallel_size: int,
|
||||
max_chunk_size_per_rank: int, chunk_idx: int):
|
||||
"""
|
||||
Context manager to compute and temporarily set the per-rank local token
|
||||
sizes for a specific chunk during chunked forward execution.
|
||||
@ -192,31 +218,40 @@ class DPMetadata:
|
||||
`chunk_idx`, this context manager sets `self.local_sizes` to the number
|
||||
of tokens to process in that chunk on each rank.
|
||||
|
||||
It uses cumulative sizes (`cu_tokens_across_dp_cpu`) to derive the
|
||||
number of tokens per rank, and calls `_compute_chunked_local_num_tokens`
|
||||
to determine the chunk-wise split.
|
||||
|
||||
`self.local_sizes` is only valid inside the context.
|
||||
|
||||
Args:
|
||||
sequence_parallel_size: When Attn is TP and MoE layers are EP,
|
||||
we use SP between the layers to avoid
|
||||
redundant ops. We need this value to
|
||||
compute the chunked sizes.
|
||||
max_chunk_size_per_rank: The max number of tokens each rank is
|
||||
allowed to process in this chunk.
|
||||
chunk_idx: The index of the chunk to compute sizes for.
|
||||
"""
|
||||
cu_sizes = self.cu_tokens_across_dp_cpu
|
||||
num_tokens_across_dp_cpu = [
|
||||
(cu_sizes[i] -
|
||||
cu_sizes[i - 1]).item() if i > 0 else cu_sizes[0].item()
|
||||
for i in range(len(cu_sizes))
|
||||
]
|
||||
self.local_sizes = _compute_chunked_local_num_tokens(
|
||||
num_tokens_across_dp_cpu, max_chunk_size_per_rank, chunk_idx)
|
||||
self.num_tokens_across_dp_cpu, sequence_parallel_size,
|
||||
max_chunk_size_per_rank, chunk_idx)
|
||||
try:
|
||||
yield self.local_sizes
|
||||
finally:
|
||||
self.local_sizes = None
|
||||
|
||||
@contextmanager
|
||||
def sp_local_sizes(self, sequence_parallel_size: int):
|
||||
"""
|
||||
Context manager for setting self.local_sizes. Same as self.chunked_sizes
|
||||
but without any chunking.
|
||||
"""
|
||||
self.local_sizes = _compute_sp_num_tokens(
|
||||
self.num_tokens_across_dp_cpu, sequence_parallel_size)
|
||||
try:
|
||||
yield self.local_sizes
|
||||
finally:
|
||||
self.local_sizes = None
|
||||
|
||||
def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
|
||||
assert self.local_sizes is not None
|
||||
return self.local_sizes
@ -3,6 +3,7 @@
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Union, get_args, overload
|
||||
|
||||
@ -983,8 +984,7 @@ class FusedMoE(CustomOp):
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
if self.is_sequence_parallel:
|
||||
self.sp_size = tp_size_
|
||||
self.sp_size = tp_size_ if is_sequence_parallel else 1
|
||||
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
@ -1966,7 +1966,8 @@ class FusedMoE(CustomOp):
|
||||
# clamp start and end
|
||||
chunk_start = min(chunk_start, num_tokens - 1)
|
||||
chunk_end = min(chunk_end, num_tokens)
|
||||
with ctx.dp_metadata.chunked_sizes(moe_dp_chunk_size_per_rank,
|
||||
with ctx.dp_metadata.chunked_sizes(self.sp_size,
|
||||
moe_dp_chunk_size_per_rank,
|
||||
chunk_idx):
|
||||
process_chunk(chunk_start,
|
||||
chunk_end,
|
||||
@ -2011,65 +2012,73 @@ class FusedMoE(CustomOp):
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
ctx = get_forward_context()
|
||||
sp_ctx = ctx.dp_metadata.sp_local_sizes(
|
||||
self.sp_size) if ctx.dp_metadata else nullcontext()
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
with sp_ctx:
|
||||
if do_naive_dispatch_combine:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits, self.is_sequence_parallel)
|
||||
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
states = get_ep_group().combine(states)
|
||||
if shared_output is not None:
|
||||
assert not isinstance(final_hidden_states, tuple)
|
||||
assert self.shared_experts is not None
|
||||
final_hidden_states = (
|
||||
shared_output,
|
||||
final_hidden_states,
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, tuple)
|
||||
final_hidden_states, zero_expert_result = final_hidden_states
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(states)
|
||||
def reduce_output(states: torch.Tensor,
|
||||
do_combine: bool = True) -> torch.Tensor:
|
||||
if do_naive_dispatch_combine and do_combine:
|
||||
states = get_ep_group().combine(states,
|
||||
self.is_sequence_parallel)
|
||||
|
||||
return states
|
||||
if (not self.is_sequence_parallel and self.reduce_results
|
||||
and (self.tp_size > 1 or self.ep_size > 1)):
|
||||
states = self.maybe_all_reduce_tensor_model_parallel(
|
||||
states)
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
return states
|
||||
|
||||
if self.shared_experts is not None:
|
||||
return (
|
||||
reduce_output(final_hidden_states[0], do_combine=False),
|
||||
reduce_output(final_hidden_states[1]),
|
||||
)
|
||||
elif self.zero_expert_num is not None and self.zero_expert_num > 0:
|
||||
assert isinstance(final_hidden_states, torch.Tensor)
|
||||
return reduce_output(final_hidden_states) + zero_expert_result
|
||||
else:
|
||||
return reduce_output(final_hidden_states)
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
@ -5,6 +5,7 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
@ -375,3 +376,20 @@ class PolyNorm(CustomOp):
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"""
|
||||
Layer Normalization.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
||||
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias,
|
||||
self.eps).type_as(x)
|
||||
|
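The new LayerNorm keeps its parameters in float32 and normalizes in float32 before casting back to the input dtype. A tiny check of that pattern:

```python
# Tiny check of the float32-compute-then-cast pattern used by the new LayerNorm.
import torch
import torch.nn.functional as F

x = torch.randn(4, 64, dtype=torch.bfloat16)
w = torch.ones(64, dtype=torch.float32)
b = torch.zeros(64, dtype=torch.float32)

y = F.layer_norm(x.float(), (64,), w, b, 1e-6).type_as(x)
print(y.dtype)            # torch.bfloat16, but the statistics were computed in fp32
```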
@ -24,6 +24,9 @@ class MLAModules:
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_b_proj: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
indexer: Optional[torch.nn.Module]
|
||||
is_sparse: bool
|
||||
topk_indices_buffer: Optional[torch.Tensor]
|
||||
|
||||
|
||||
@CustomOp.register("multi_head_latent_attention")
|
||||
@ -76,6 +79,13 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
self.kv_b_proj = mla_modules.kv_b_proj
|
||||
self.rotary_emb = mla_modules.rotary_emb
|
||||
self.o_proj = mla_modules.o_proj
|
||||
self.indexer = mla_modules.indexer
|
||||
self.is_sparse = mla_modules.is_sparse
|
||||
|
||||
if self.indexer is not None:
|
||||
assert hasattr(self.indexer, "topk_tokens")
|
||||
self.topk_tokens = self.indexer.topk_tokens
|
||||
self.topk_indices_buffer = mla_modules.topk_indices_buffer
|
||||
|
||||
# In the MLA backend, kv_cache includes both k_c and
|
||||
# pe (i.e. decoupled position embeddings). In particular,
|
||||
@ -92,6 +102,7 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=mla_modules.is_sparse,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
@ -100,6 +111,7 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
indexer=self.indexer,
|
||||
)
|
||||
|
||||
self.prefix = prefix
|
||||
@ -145,6 +157,10 @@ class MultiHeadLatentAttention(CustomOp):
|
||||
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
|
||||
positions, q[..., self.qk_nope_head_dim:], k_pe)
|
||||
|
||||
if self.indexer and self.is_sparse:
|
||||
_topk_indices = self.indexer(hidden_states, q_c, positions,
|
||||
self.rotary_emb)
|
||||
|
||||
attn_out = self.mla_attn(
|
||||
q,
|
||||
kv_c_normed,
@ -911,15 +911,15 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module,
|
||||
# On Blackwell or Hopper, if E8M0 for DeepGemm is used, we need to
|
||||
# requantize the weight and input to the specific scale
|
||||
# at the same time.
|
||||
if is_deep_gemm_e8m0_used():
|
||||
should_use_deepgemm = should_use_deepgemm_for_fp8_linear(
|
||||
layer.orig_dtype, layer.weight)
|
||||
if is_deep_gemm_e8m0_used() and should_use_deepgemm:
|
||||
block_sz = tuple(layer.weight_block_size)
|
||||
requant_weight_ue8m0_inplace(layer.weight.data,
|
||||
layer.weight_scale.data, block_sz)
|
||||
# SM90 Block FP8 CUTLASS requires row-major weight scales
|
||||
elif (current_platform.is_device_capability(90)
|
||||
and cutlass_block_fp8_supported
|
||||
and not should_use_deepgemm_for_fp8_linear(torch.bfloat16,
|
||||
layer.weight)):
|
||||
and cutlass_block_fp8_supported and not should_use_deepgemm):
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data.T.contiguous(), requires_grad=False)
|
||||
|
||||
|
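When the DeepGemm E8M0 path is taken, the block scales are requantized to a power-of-two format and the weights rescaled to compensate. A rough sketch of that idea only; the round-up convention below is an assumption and the actual requant_weight_ue8m0_inplace kernel may differ in detail:

```python
# Rough sketch of power-of-two ("UE8M0"-style) scale requantization. The
# round-up convention is an assumption for illustration; the real kernel
# operates on fp8 block-quantized weights and may differ.
import torch

def requant_scales_pow2(weight: torch.Tensor, scale: torch.Tensor):
    # Snap each block scale to the next power of two and fold the ratio back
    # into the weight so that weight * scale is preserved.
    new_scale = torch.exp2(torch.ceil(torch.log2(scale)))
    weight *= (scale / new_scale)
    return weight, new_scale

w = torch.randn(4, 4)
s = torch.tensor(0.7)
w2, s2 = requant_scales_pow2(w.clone(), s)
torch.testing.assert_close(w * s, w2 * s2)
print(s2)   # tensor(1.)
```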
@ -9,7 +9,7 @@ from transformers import AriaConfig, AriaTextConfig, BatchFeature
|
||||
from transformers.models.aria.modeling_aria import AriaCrossAttention
|
||||
from transformers.models.aria.processing_aria import AriaProcessor
|
||||
|
||||
from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
|
||||
from vllm.config import QuantizationConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -298,14 +298,12 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
|
||||
Experts (MoE) Layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AriaTextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, cache_config, quant_config, prefix)
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__(vllm_config, prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.mlp = AriaTextMoELayer(config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
@ -346,8 +346,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
block_size=1,
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=model_config.use_mla).page_size_bytes
|
||||
dtype=kv_cache_dtype).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
@ -401,6 +400,31 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
"exactly equal.", mamba_padding_pct)
|
||||
|
||||
|
||||
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Updates the fp8 kv-cache to the custom "fp8_ds_mla" format for DeepSeekV32
|
||||
"""
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
|
||||
# Mirror the check in vllm/model_executor/models/deepseek_v2.py
|
||||
is_v32 = hasattr(hf_config, "index_topk")
|
||||
assert is_v32
|
||||
|
||||
# For DeepSeekV3.2, use the custom fp8 format by default (i.e. when
# cache_dtype is "auto")
|
||||
cache_config = vllm_config.cache_config
|
||||
if cache_config.cache_dtype == "auto" or \
|
||||
cache_config.cache_dtype.startswith("fp8"):
|
||||
cache_config.cache_dtype = "fp8_ds_mla"
|
||||
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
|
||||
if cache_config.cache_dtype == "bfloat16":
|
||||
cache_config.cache_dtype = "auto"
|
||||
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
||||
|
||||
|
||||
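A tiny sketch of the cache-dtype mapping the hook above applies for DeepSeek V3.2:

```python
# Tiny sketch of the cache-dtype mapping applied for DeepSeek V3.2 above.
def resolve_v32_cache_dtype(cache_dtype: str) -> str:
    if cache_dtype == "auto" or cache_dtype.startswith("fp8"):
        return "fp8_ds_mla"     # custom fp8 kv-cache layout for V3.2
    if cache_dtype == "bfloat16":
        return "auto"           # fall back to the model dtype
    return cache_dtype

assert resolve_v32_cache_dtype("auto") == "fp8_ds_mla"
assert resolve_v32_cache_dtype("fp8_e4m3") == "fp8_ds_mla"
assert resolve_v32_cache_dtype("bfloat16") == "auto"
```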
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
@ -417,4 +441,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"MambaForCausalLM": MambaModelConfig,
|
||||
"Mamba2ForCausalLM": MambaModelConfig,
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
|
||||
}
@ -53,8 +53,20 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
|
||||
self.is_v32 = hasattr(config, "index_topk")
|
||||
if self.is_v32:
|
||||
topk_tokens = config.index_topk
|
||||
topk_indices_buffer = torch.empty(
|
||||
vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
topk_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
else:
|
||||
topk_indices_buffer = None
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix)
|
||||
self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
|
||||
topk_indices_buffer)
|
||||
|
||||
def forward(
|
||||
self,
@ -32,17 +32,22 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config import (CacheConfig, ParallelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
@ -50,20 +55,35 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv, direct_register_custom_op
|
||||
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
|
||||
from vllm.v1.attention.backends.mla.indexer import (DeepseekV32IndexerBackend,
|
||||
DeepseekV32IndexerMetadata)
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
|
||||
@ -108,43 +128,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Chunk x along the num_tokens axis for sequence parallelism
|
||||
# NOTE: This is wrapped in a torch custom op to work around the following issue:
|
||||
# The output tensor can have a sequence length 0 at small input sequence lengths
|
||||
# even though we explicitly pad to avoid this.
|
||||
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
# all_gather needs the sequence length to be divisible by tp_size
|
||||
seq_len = x.size(0)
|
||||
remainder = seq_len % tp_size
|
||||
if remainder != 0:
|
||||
pad_len = tp_size - remainder
|
||||
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
||||
|
||||
chunk = x.shape[0] // tp_size
|
||||
start = tp_rank * chunk
|
||||
return torch.narrow(x, 0, start, chunk)
|
||||
|
||||
|
||||
def sequence_parallel_chunk_fake(x: torch.Tensor) -> torch.Tensor:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
seq_len = cdiv(x.size(0), tp_size)
|
||||
shape = list(x.shape)
|
||||
shape[0] = seq_len
|
||||
out = torch.empty(shape, dtype=x.dtype, device=x.device)
|
||||
return out
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sequence_parallel_chunk",
|
||||
op_func=sequence_parallel_chunk,
|
||||
fake_impl=sequence_parallel_chunk_fake,
|
||||
tags=(torch.Tag.needs_fixed_stride_order, ),
|
||||
)
|
||||
|
||||
|
||||
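The custom op above pads the token dimension to a multiple of tp_size and narrows out this rank's chunk. A small worked example of the same pad-and-narrow logic:

```python
# Small worked example of the pad-and-narrow logic in the custom op above.
import torch
import torch.nn.functional as F

def sp_chunk(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    remainder = x.size(0) % tp_size
    if remainder != 0:
        # Pad rows so the sequence length divides evenly across ranks.
        x = F.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.shape[0] // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.arange(10 * 4, dtype=torch.float32).reshape(10, 4)
print([sp_chunk(x, 4, r).shape for r in range(4)])
# 10 tokens pad to 12, so every rank gets a [3, 4] chunk; the last chunk is
# partly zero-padding.
```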
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -166,20 +149,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
# The all_reduce at the end of attention (during o_proj) means that
|
||||
# inputs are replicated across each rank of the tensor parallel group.
|
||||
# If using expert-parallelism with DeepEP All2All ops, replicated
|
||||
# tokens results in useless duplicate computation and communication.
|
||||
#
|
||||
# In this case, ensure the input to the experts is sequence parallel
|
||||
# to avoid the excess work.
|
||||
#
|
||||
# Not needed for pplx-kernels as it can handle duplicate input tokens.
|
||||
self.is_sequence_parallel = (envs.VLLM_ALL2ALL_BACKEND
|
||||
in ("deepep_high_throughput",
|
||||
"deepep_low_latency")
|
||||
and parallel_config.enable_expert_parallel
|
||||
and self.tp_size > 1)
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
@ -278,8 +248,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
# TODO: We can replace the all_reduce at the end of attn with a
|
||||
# reduce_scatter instead of chunking here.
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = torch.ops.vllm.sequence_parallel_chunk(
|
||||
hidden_states)
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
@ -328,6 +297,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
@ -341,6 +311,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
topk_indices_buffer: Optional[torch.Tensor] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -358,6 +329,8 @@ class DeepseekV2Attention(nn.Module):
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
assert topk_indices_buffer is None, "topk_indices_buffer is not \
|
||||
supported for DeepseekV2Attention"
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
||||
@ -470,6 +443,390 @@ class DeepseekV2Attention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
|
||||
|
||||
def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str,
|
||||
cache_config: CacheConfig):
|
||||
super().__init__()
|
||||
self.kv_cache = [torch.tensor([])]
|
||||
self.head_dim = head_dim
|
||||
self.prefix = prefix
|
||||
self.cache_config = cache_config
|
||||
self.dtype = dtype
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def get_kv_cache_spec(self) -> KVCacheSpec:
|
||||
return MLAAttentionSpec( # Only has one vector instead of K + V
|
||||
block_size=self.cache_config.block_size,
|
||||
num_kv_heads=1,
|
||||
head_size=self.head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def forward(self):
|
||||
...
|
||||
|
||||
def get_attn_backend(self) -> AttentionBackend:
|
||||
return DeepseekV32IndexerBackend
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def cp_gather_indexer_k_quant_cache(
|
||||
kv_cache, # [num_blocks, block_size, head_dim + 1]
|
||||
dst_value, # [cu_seq_lens[-1], head_dim]
|
||||
dst_scale, # [cu_seq_lens[-1], 4]
|
||||
block_table, # [batch_size, num_blocks]
|
||||
cu_seq_lens, # [batch_size + 1, ]
|
||||
batch_size,
|
||||
):
|
||||
num_blocks, block_size, _ = kv_cache.shape
|
||||
head_dim = dst_value.shape[-1]
|
||||
kv_cache = kv_cache.view(num_blocks, -1)
|
||||
|
||||
expected_value = []
|
||||
expected_scale = []
|
||||
for b in range(batch_size):
|
||||
s = cu_seq_lens[b + 1] - cu_seq_lens[b]
|
||||
if s == 0:
|
||||
continue
|
||||
tot = cdiv(s, block_size)
|
||||
blocks = block_table[b, :tot]
|
||||
|
||||
value = []
|
||||
scale = []
|
||||
full_block = torch.arange(tot - 1,
|
||||
device=kv_cache.device,
|
||||
dtype=torch.int32)
|
||||
non_remaining_value = kv_cache[blocks[full_block], :block_size *
|
||||
head_dim].view(-1, head_dim)
|
||||
non_remaining_scale = kv_cache[blocks[full_block],
|
||||
block_size * head_dim:].view(-1, 4)
|
||||
|
||||
remaining = s - (tot - 1) * block_size
|
||||
|
||||
value = torch.cat([
|
||||
non_remaining_value,
|
||||
kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
|
||||
],
|
||||
dim=0)
|
||||
scale = torch.cat([
|
||||
non_remaining_scale,
|
||||
kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
|
||||
remaining * 4].view(-1, 4)
|
||||
],
|
||||
dim=0)
|
||||
|
||||
expected_value.append(value)
|
||||
expected_scale.append(scale)
|
||||
|
||||
gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
|
||||
gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
|
||||
gather_value = gather_value.view(torch.float8_e4m3fn)
|
||||
gather_scale = gather_scale.view(torch.float32)
|
||||
dst_value.copy_(gather_value)
|
||||
dst_scale.copy_(gather_scale)
|
||||
|
||||
|
||||
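The gather above assumes each indexer cache block packs block_size*head_dim fp8 value bytes followed by block_size*4 fp32 scale bytes. A sketch of slicing one sequence's values and scales out of that packed layout, with assumed shapes:

```python
# Sketch of the packed indexer-cache block layout the gather above assumes:
# per block, block_size*head_dim fp8 value bytes followed by block_size*4
# scale bytes (one fp32 scale per token).
import torch

block_size, head_dim = 4, 8
num_blocks = 3
flat = torch.zeros(num_blocks, block_size * head_dim + block_size * 4,
                   dtype=torch.uint8)

def slice_block(kv_cache_flat: torch.Tensor, block_id: int, n_tokens: int):
    row = kv_cache_flat[block_id]
    values = row[:n_tokens * head_dim].view(-1, head_dim)           # fp8 bytes
    scales = row[block_size * head_dim:
                 block_size * head_dim + n_tokens * 4].view(-1, 4)  # fp32 bytes
    return values, scales

v, s = slice_block(flat, block_id=0, n_tokens=3)
print(v.shape, s.shape)   # torch.Size([3, 8]) torch.Size([3, 4])
```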
def sparse_attn_indexer(
|
||||
hidden_states: torch.Tensor,
|
||||
k_cache_prefix: str,
|
||||
kv_cache: torch.Tensor,
|
||||
q_fp8: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
quant_block_size: int,
|
||||
scale_fmt: Optional[str],
|
||||
topk_tokens: int,
|
||||
head_dim: int,
|
||||
max_model_len: int,
|
||||
total_seq_lens: int,
|
||||
topk_indices_buffer: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
|
||||
# Careful! This will be None during a dummy run.
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# assert isinstance(attn_metadata, dict)
|
||||
if not isinstance(attn_metadata, dict):
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
kv_cache,
|
||||
q_fp8,
|
||||
k,
|
||||
weights,
|
||||
quant_block_size,
|
||||
scale_fmt,
|
||||
topk_tokens,
|
||||
head_dim,
|
||||
max_model_len,
|
||||
total_seq_lens,
|
||||
topk_indices_buffer,
|
||||
)
|
||||
attn_metadata = attn_metadata[k_cache_prefix]
|
||||
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
has_decode = attn_metadata.num_decodes > 0
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
|
||||
ops.indexer_k_quant_and_cache(
|
||||
k,
|
||||
kv_cache,
|
||||
slot_mapping,
|
||||
quant_block_size,
|
||||
scale_fmt,
|
||||
)
|
||||
|
||||
topk_indices_buffer[:hidden_states.shape[0]] = -1
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
|
||||
device=k.device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
k_scale = torch.empty([chunk.total_seq_lens, 1],
|
||||
device=k.device,
|
||||
dtype=torch.float32)
|
||||
cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
k_scale,
|
||||
chunk.block_table,
|
||||
chunk.cu_seq_lens,
|
||||
chunk.num_reqs,
|
||||
)
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8[chunk.token_start:chunk.token_end],
|
||||
(k_fp8, k_scale),
|
||||
weights[chunk.token_start:chunk.token_end],
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
)
|
||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
||||
dim=-1)[1]
|
||||
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||
mask_lo = topk_indices >= 0
|
||||
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||
chunk.cu_seqlen_ks)[:, None] < 0
|
||||
mask = torch.full_like(topk_indices,
|
||||
False,
|
||||
dtype=torch.bool,
|
||||
device=topk_indices.device)
|
||||
mask = mask_lo & mask_hi
|
||||
topk_indices = topk_indices.masked_fill(~mask, -1)
|
||||
topk_indices_buffer[
|
||||
chunk.token_start:chunk.token_end, :topk_indices.
|
||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
if has_decode:
|
||||
decode_metadata = attn_metadata.decode
|
||||
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
|
||||
# we only have [num_block, block_size, head_dim],
|
||||
kv_cache = kv_cache.unsqueeze(-2)
|
||||
decode_lens = decode_metadata.decode_lens
|
||||
if decode_metadata.requires_padding:
|
||||
# pad in edge case where we have short chunked prefill length <
|
||||
# decode_threshold since we unstrictly split
|
||||
# prefill and decode by decode_threshold
|
||||
# (currently set to 1 + speculative tokens)
|
||||
padded_q_fp8_decode_tokens = pack_seq_triton(
|
||||
q_fp8[:num_decode_tokens], decode_lens)
|
||||
else:
|
||||
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
|
||||
decode_lens.shape[0], -1, *q_fp8.shape[1:])
|
||||
# TODO: move and optimize below logic with triton kernels
|
||||
batch_size = padded_q_fp8_decode_tokens.shape[0]
|
||||
next_n = padded_q_fp8_decode_tokens.shape[1]
|
||||
assert batch_size == decode_metadata.seq_lens.shape[0]
|
||||
num_padded_tokens = batch_size * next_n
|
||||
logits = fp8_paged_mqa_logits(
|
||||
padded_q_fp8_decode_tokens,
|
||||
kv_cache,
|
||||
weights[:num_padded_tokens],
|
||||
decode_metadata.seq_lens,
|
||||
decode_metadata.block_table,
|
||||
decode_metadata.schedule_metadata,
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
# padded query len
|
||||
current_device = padded_q_fp8_decode_tokens.device
|
||||
padded_num_tokens = batch_size * next_n
|
||||
positions = torch.arange(max_model_len,
|
||||
device=current_device).unsqueeze(0).expand(
|
||||
batch_size * next_n, -1)
|
||||
row_indices = torch.arange(padded_num_tokens,
|
||||
device=current_device) // next_n
|
||||
next_n_offset = torch.arange(
|
||||
padded_num_tokens,
|
||||
device=padded_q_fp8_decode_tokens.device) % next_n
|
||||
index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
|
||||
next_n_offset).unsqueeze(1)
|
||||
# index_end_pos: [B * N, 1]
|
||||
mask = positions <= index_end_pos
|
||||
# mask: [B * N, L]
|
||||
logits = logits.masked_fill(~mask, float('-inf'))
|
||||
topk_indices = logits.topk(topk_tokens,
|
||||
dim=-1)[1].to(torch.int32) # [B * N, K]
|
||||
# ensure we don't set indices for the top k
|
||||
# that is out of range(masked already)
|
||||
# this will happen if context length is shorter than K
|
||||
topk_indices[topk_indices > index_end_pos] = -1
|
||||
if decode_metadata.requires_padding:
|
||||
# if padded, we need to unpack
|
||||
# the topk indices removing padded tokens
|
||||
topk_indices = unpack_seq_triton(
|
||||
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
|
||||
decode_lens)
|
||||
topk_indices_buffer[:num_decode_tokens, :topk_indices.
|
||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||
|
||||
return topk_indices_buffer
|
||||
|
||||
|
||||
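The decode branch above first masks every logit outside a token's causal window and then marks any top-k slot that still points past the window as -1 (which happens whenever the visible context is shorter than K). The toy sketch below is not part of the diff; it reproduces just that masking step with made-up shapes.

# Illustrative sketch only (toy shapes, next_n = 1); not part of this commit.
import torch

max_model_len, topk = 8, 4
seq_lens = torch.tensor([3, 6])          # two decode requests
logits = torch.randn(2, max_model_len)   # one query row per request

index_end_pos = (seq_lens - 1).unsqueeze(1)           # last visible position
positions = torch.arange(max_model_len).unsqueeze(0)  # [1, L]
mask = positions <= index_end_pos                     # [B, L] causal window
logits = logits.masked_fill(~mask, float('-inf'))

topk_indices = logits.topk(topk, dim=-1)[1].to(torch.int32)
topk_indices[topk_indices > index_end_pos] = -1       # slots beyond seq_len

print(topk_indices)  # request 0 has only 3 valid positions, so one slot is -1
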
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                device=k.device,
                                dtype=torch.uint8)
    _k_fp8 = _flattened_kv[..., :head_dim].view(
        torch.float8_e4m3fn).contiguous()
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer


direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)


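Registering a fake (meta) implementation alongside the real op is what lets torch.compile and the profile run trace the indexer without executing its kernels. The sketch below is not from the diff; it shows the same pattern with the plain PyTorch 2.4+ torch.library API, and the op/function names in it are illustrative only.

# Illustrative sketch only (hypothetical op name "demo::fill_topk"); not part of this commit.
import torch

@torch.library.custom_op("demo::fill_topk", mutates_args=("out", ))
def fill_topk(scores: torch.Tensor, out: torch.Tensor) -> None:
    # Real implementation: write top-k indices into a preallocated buffer.
    k = out.shape[-1]
    out.copy_(scores.topk(k, dim=-1)[1].to(out.dtype))

@fill_topk.register_fake
def _(scores: torch.Tensor, out: torch.Tensor) -> None:
    # Fake implementation: only shapes and dtypes matter during tracing.
    return None

scores = torch.randn(2, 8)
out = torch.empty(2, 3, dtype=torch.int64)
torch.ops.demo.fill_topk(scores, out)
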
class Indexer(nn.Module):

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: int,
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor],
                 prefix: str = ""):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536
        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(self.q_lora_rank,
                                     self.head_dim * self.n_head,
                                     bias=False,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.wq_b")
        self.wk = ReplicatedLinear(hidden_size,
                                   self.head_dim,
                                   bias=False,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.wk")
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use a naive fp8 cache, storing the value in fp8
        # and the scale in fp32, one scale per self.quant_block_size elements.
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim +
            self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config)
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import (
            get_max_prefill_buffer_size)
        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
                rotary_emb) -> torch.Tensor:
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(q,
                                                   self.quant_block_size,
                                                   column_major_scales=False,
                                                   use_ue8m0=self.scale_fmt
                                                   is not None)
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

        weights, _ = self.weights_proj(hidden_states)
        weights = weights.unsqueeze(
            -1) * q_scale * self.softmax_scale * self.n_head**-0.5
        weights = weights.squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )


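Indexer.forward quantizes only the query tensor on the fly (key quantization is fused into the cache insertion kernel). The rough sketch below is not part of the diff; it only illustrates what a per-token, per-group fp8 quantization of this kind produces, assuming the group size divides the last dimension and using 448.0 as the e4m3 finite maximum.

# Illustrative sketch only (naive reference, not the fused vLLM kernel).
import torch

def naive_group_quant_fp8(x: torch.Tensor, group_size: int):
    orig_shape = x.shape
    x = x.reshape(-1, group_size)
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
    scale = amax / 448.0                      # e4m3 finite max
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8.reshape(orig_shape), scale.reshape(*orig_shape[:-1], -1)

q = torch.randn(4, 128)                       # 4 tokens, head_dim = 128
q_fp8, q_scale = naive_group_quant_fp8(q, group_size=128)
print(q_fp8.dtype, q_scale.shape)             # torch.float8_e4m3fn, (4, 1)
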
class DeepseekV2MLAAttention(nn.Module):
|
||||
"""
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
@ -481,6 +838,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: Union[DeepseekV2Config, DeepseekV3Config],
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
@ -495,6 +853,7 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
topk_indices_buffer: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -575,6 +934,15 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
self.is_v32 = hasattr(config, "index_topk")
|
||||
|
||||
if self.is_v32:
|
||||
self.indexer = Indexer(vllm_config, config, hidden_size,
|
||||
q_lora_rank, quant_config, cache_config,
|
||||
topk_indices_buffer, f"{prefix}.indexer")
|
||||
else:
|
||||
self.indexer = None
|
||||
|
||||
mla_modules = MLAModules(
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
@ -588,7 +956,11 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else None,
|
||||
indexer=self.indexer,
|
||||
is_sparse=self.is_v32,
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
self.mla_attn = MultiHeadLatentAttention(
|
||||
self.hidden_size,
|
||||
self.num_local_heads,
|
||||
@ -614,7 +986,10 @@ class DeepseekV2MLAAttention(nn.Module):
|
||||
|
||||
class DeepseekV2DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
topk_indices_buffer: Optional[torch.Tensor] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
@ -637,6 +1012,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@ -652,6 +1028,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
topk_indices_buffer=topk_indices_buffer,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
@ -735,6 +1112,16 @@ class DeepseekV2Model(nn.Module):
        self.config = config

        self.vocab_size = config.vocab_size
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device="cuda")
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
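The shared int32 buffer allocated here costs max_num_batched_tokens * index_topk * 4 bytes; the quick check below uses illustrative numbers (not taken from the diff) just to show the order of magnitude.

# Back-of-the-envelope check with assumed values; not part of this commit.
max_num_batched_tokens = 8192   # assumed scheduler setting
index_topk = 2048               # assumed indexer top-k from the model config
buffer_bytes = max_num_batched_tokens * index_topk * 4
print(f"{buffer_bytes / 2**20:.0f} MiB")  # 64 MiB
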
@ -747,7 +1134,8 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix),
|
||||
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
|
||||
topk_indices_buffer),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
|
@ -29,10 +29,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -47,13 +46,11 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.mtp_emb_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -62,8 +59,7 @@ class ErnieMultiTokenPredictorLayer(nn.Module):
|
||||
self.mtp_linear_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.mtp_block = LlamaDecoderLayer(config, cache_config, quant_config,
|
||||
prefix)
|
||||
self.mtp_block = LlamaDecoderLayer(vllm_config, prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -102,10 +98,8 @@ class ErnieMultiTokenPredictor(nn.Module):
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
ErnieMultiTokenPredictorLayer(
|
||||
config,
|
||||
vllm_config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
|
@ -136,14 +136,16 @@ class Glm4Attention(nn.Module):
|
||||
|
||||
class Glm4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Glm4Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[Glm4Config] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
|
@ -13,7 +13,8 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -24,6 +25,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@ -132,12 +134,18 @@ class MLPBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GptOssConfig,
|
||||
vllm_config: VllmConfig,
|
||||
layer_idx: int,
|
||||
quant_config: QuantizationConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.num_experts = config.num_local_experts
|
||||
self.experts_per_token = config.num_experts_per_tok
|
||||
@ -155,11 +163,20 @@ class MLPBlock(torch.nn.Module):
|
||||
prefix=f"{prefix}.experts",
|
||||
apply_router_weight_on_input=False,
|
||||
has_bias=True,
|
||||
activation="swigluoai")
|
||||
activation="swigluoai",
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens = x.shape[0]
|
||||
if self.is_sequence_parallel:
|
||||
x = sequence_parallel_chunk(x)
|
||||
|
||||
g = self.router(x)
|
||||
x = self.experts(hidden_states=x, router_logits=g)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
|
||||
x = x[:num_tokens]
|
||||
return x
|
||||
|
||||
|
||||
@ -167,19 +184,20 @@ class TransformerBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GptOssConfig,
|
||||
cache_config: CacheConfig,
|
||||
quant_config: QuantizationConfig,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.attn = OAIAttention(config,
|
||||
prefix=f"{prefix}.attn",
|
||||
cache_config=cache_config)
|
||||
self.mlp = MLPBlock(config,
|
||||
self.mlp = MLPBlock(vllm_config,
|
||||
self.layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
@ -216,8 +234,6 @@ class GptOssModel(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.quant_config = vllm_config.quant_config
|
||||
self.parallel_config = vllm_config.parallel_config
|
||||
self.config.hidden_size = self.config.hidden_size
|
||||
self.embedding = VocabParallelEmbedding(
|
||||
@ -227,9 +243,7 @@ class GptOssModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
self.config.num_hidden_layers,
|
||||
lambda prefix: TransformerBlock(
|
||||
self.config,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
vllm_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
|
@ -29,12 +29,13 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.models.granitemoe import GraniteMoeConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -48,6 +49,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
@ -71,9 +73,11 @@ class GraniteMoeMoE(nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
is_sequence_parallel=False,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.is_sequence_parallel = is_sequence_parallel
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(hidden_size,
|
||||
@ -92,15 +96,27 @@ class GraniteMoeMoE(nn.Module):
|
||||
renormalize=True,
|
||||
quant_config=quant_config,
|
||||
tp_size=tp_size,
|
||||
prefix=f"{prefix}.experts")
|
||||
prefix=f"{prefix}.experts",
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(hidden_states, router_logits)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
num_tokens = orig_shape[0]
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
@ -191,12 +207,16 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteMoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -218,6 +238,7 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
is_sequence_parallel=parallel_config.use_sequence_parallel_moe,
|
||||
prefix=f"{prefix}.block_sparse_moe")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -255,7 +276,6 @@ class GraniteMoeModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -275,9 +295,7 @@ class GraniteMoeModel(nn.Module):
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: GraniteMoeDecoderLayer(
|
||||
config, cache_config, quant_config=quant_config, prefix=prefix
|
||||
),
|
||||
lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
@ -68,6 +68,7 @@ class LlamaMLP(nn.Module):
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
disable_tp: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
@ -75,6 +76,7 @@ class LlamaMLP(nn.Module):
|
||||
output_sizes=[intermediate_size] * 2,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
disable_tp=disable_tp,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
@ -83,6 +85,7 @@ class LlamaMLP(nn.Module):
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
disable_tp=disable_tp,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
@ -237,14 +240,16 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
@ -335,7 +340,6 @@ class LlamaModel(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -357,10 +361,7 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: layer_type(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
lambda prefix: layer_type(vllm_config=vllm_config, prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
|
@ -28,7 +28,8 @@ from vllm.attention import Attention
|
||||
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
@ -39,6 +40,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
|
||||
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
|
||||
@ -59,13 +61,16 @@ class Llama4MoE(nn.Module):
|
||||
router_scores = torch.sigmoid(router_scores.float())
|
||||
return (router_scores, router_indices.to(torch.int32))
|
||||
|
||||
def __init__(self,
|
||||
config: Llama4TextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
intermediate_size_moe = config.intermediate_size
|
||||
self.router = ReplicatedLinear(config.hidden_size,
|
||||
@ -82,6 +87,7 @@ class Llama4MoE(nn.Module):
|
||||
bias=False,
|
||||
prefix=f"{prefix}.shared_expert",
|
||||
reduce_results=False,
|
||||
disable_tp=self.is_sequence_parallel,
|
||||
)
|
||||
|
||||
self.experts = SharedFusedMoE(
|
||||
@ -96,9 +102,14 @@ class Llama4MoE(nn.Module):
|
||||
renormalize=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
is_sequence_parallel=self.is_sequence_parallel,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
router_logits, _ = self.router(hidden_states)
|
||||
|
||||
shared_out, routed_out = self.experts(
|
||||
@ -107,7 +118,10 @@ class Llama4MoE(nn.Module):
|
||||
)
|
||||
experts_out = routed_out + shared_out
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.is_sequence_parallel:
|
||||
experts_out = tensor_model_parallel_all_gather(experts_out, 0)
|
||||
experts_out = experts_out[:num_tokens]
|
||||
elif self.tp_size > 1:
|
||||
experts_out = self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
experts_out)
|
||||
|
||||
@ -257,15 +271,16 @@ class Llama4Attention(nn.Module):
|
||||
|
||||
class Llama4DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4TextConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[Llama4TextConfig] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -291,8 +306,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
self.layer_idx + 1) % config.interleave_moe_layer_step == 0
|
||||
if is_moe_layer:
|
||||
self.feed_forward = Llama4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
else:
|
||||
|
@ -68,9 +68,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
Llama4DecoderLayer(
|
||||
self.config,
|
||||
quant_config=quant_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
config=self.config,
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||
|
@ -28,11 +28,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
vllm_config: VllmConfig,
|
||||
disable_input_layernorm: bool,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None,
|
||||
) -> None:
|
||||
super().__init__(config, prefix=prefix)
|
||||
super().__init__(vllm_config, prefix=prefix, config=config)
|
||||
|
||||
# Skip the input_layernorm
|
||||
# https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
|
||||
@ -64,9 +65,10 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
self.config,
|
||||
vllm_config,
|
||||
i == 0,
|
||||
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
|
||||
config=self.config,
|
||||
) for i in range(self.config.num_hidden_layers)
|
||||
])
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 2,
|
||||
|
@ -9,13 +9,11 @@ import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -29,17 +27,14 @@ logger = init_logger(__name__)
|
||||
|
||||
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
config: Optional[LlamaConfig] = None) -> None:
|
||||
super().__init__(vllm_config, prefix=prefix, config=config)
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
# override qkv
|
||||
self.self_attn.qkv_proj = QKVParallelLinear(
|
||||
@ -127,9 +122,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
config=self.config,
|
||||
cache_config=current_vllm_config.cache_config,
|
||||
current_vllm_config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
||||
config=self.config,
|
||||
)
|
||||
])
|
||||
if hasattr(self.config, "target_hidden_size"):
|
||||
|
@ -308,6 +308,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: FlashConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
@ -329,6 +330,7 @@ class FlashDecoderLayer(nn.Module):
|
||||
# Dual attention structure
|
||||
self.self_attn = nn.ModuleList([
|
||||
DeepseekV2MLAAttention(
|
||||
vllm_config=vllm_config,
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@ -454,6 +456,7 @@ class FlashModel(nn.Module):
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: FlashDecoderLayer(
|
||||
vllm_config,
|
||||
config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
|
@ -274,6 +274,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# Per attention head and per partition values.
|
||||
@ -300,25 +302,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.proj",
|
||||
disable_tp=use_data_parallel)
|
||||
|
||||
# Detect attention implementation.
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
dtype=torch.get_default_dtype())
|
||||
self.use_upstream_fa = False
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
self.use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
self.attn_backend = attn_backend
|
||||
self.use_upstream_fa = use_upstream_fa
|
||||
self.is_flash_attn_backend = self.attn_backend in {
|
||||
_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA
|
||||
}
|
||||
@ -443,6 +428,8 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend: _Backend = _Backend.TORCH_SDPA,
|
||||
use_upstream_fa: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if norm_layer is None:
|
||||
@ -455,7 +442,9 @@ class Qwen2_5_VisionBlock(nn.Module):
|
||||
projection_size=dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_data_parallel=use_data_parallel)
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=attn_backend,
|
||||
use_upstream_fa=use_upstream_fa)
|
||||
self.mlp = Qwen2_5_VisionMLP(dim,
|
||||
mlp_hidden_dim,
|
||||
act_fn=act_fn,
|
||||
@ -627,17 +616,35 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
head_dim = self.hidden_size // self.num_heads
|
||||
self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
|
||||
|
||||
use_upstream_fa = False
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
use_upstream_fa = True
|
||||
|
||||
if self.attn_backend not in {
|
||||
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
|
||||
_Backend.ROCM_AITER_FA
|
||||
}:
|
||||
raise RuntimeError(
|
||||
f"Qwen2.5-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
Qwen2_5_VisionBlock(dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=get_act_and_mul_fn(
|
||||
vision_config.hidden_act),
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel)
|
||||
for layer_idx in range(depth)
|
||||
Qwen2_5_VisionBlock(
|
||||
dim=self.hidden_size,
|
||||
num_heads=self.num_heads,
|
||||
mlp_hidden_dim=vision_config.intermediate_size,
|
||||
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.blocks.{layer_idx}",
|
||||
use_data_parallel=use_data_parallel,
|
||||
attn_backend=self.attn_backend,
|
||||
use_upstream_fa=use_upstream_fa) for layer_idx in range(depth)
|
||||
])
|
||||
self.merger = Qwen2_5_VisionPatchMerger(
|
||||
d_model=vision_config.out_hidden_size,
|
||||
@ -648,12 +655,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
prefix=f"{prefix}.merger",
|
||||
use_data_parallel=use_data_parallel,
|
||||
)
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim, dtype=torch.get_default_dtype())
|
||||
if self.attn_backend != _Backend.FLASH_ATTN and \
|
||||
check_upstream_fa_availability(
|
||||
torch.get_default_dtype()):
|
||||
self.attn_backend = _Backend.FLASH_ATTN
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
@ -79,7 +79,7 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 32
|
||||
_MAX_FRAMES_PER_VIDEO = 14
|
||||
|
||||
# === Vision Inputs === #
|
||||
|
||||
@ -932,6 +932,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
_, num_image_tokens = self._get_vision_info(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
num_frames=1,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
return num_image_tokens
|
||||
@ -956,6 +957,7 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
max_image_size, _ = self._get_vision_info(
|
||||
image_width=9999999,
|
||||
image_height=9999999,
|
||||
num_frames=1,
|
||||
image_processor=None,
|
||||
)
|
||||
return max_image_size
|
||||
@ -969,10 +971,12 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
def _get_max_video_frames(self, max_tokens: int) -> int:
|
||||
def _get_max_video_frames(self,
|
||||
max_tokens: int,
|
||||
start_num_frames: int = 1) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
num_frames = 0
|
||||
num_frames = start_num_frames
|
||||
|
||||
while True:
|
||||
next_num_frames = num_frames + 1
|
||||
@ -994,12 +998,13 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
max_frames_per_video: int = _MAX_FRAMES_PER_VIDEO,
|
||||
) -> int:
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
|
||||
max_total_frames = self._get_max_video_frames(seq_len)
|
||||
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
|
||||
_MAX_FRAMES_PER_VIDEO)
|
||||
max_frames_per_video)
|
||||
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
|
@ -29,13 +29,13 @@ from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Qwen3MoeConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
|
||||
@ -101,12 +102,15 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3MoeConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
@ -114,6 +118,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts = config.num_experts
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
@ -122,7 +128,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = enable_eplb
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
@ -144,7 +150,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.num_experts,
|
||||
@ -156,14 +163,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
        assert hidden_states.dim(
        ) <= 2, "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
        is_input_1d = hidden_states.dim() == 1
        hidden_dim = hidden_states.shape[-1]
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0)
            final_hidden_states = final_hidden_states[:num_tokens]

        # return to 1d if input is 1d
        return final_hidden_states.squeeze(0) if is_input_1d else \
            final_hidden_states
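Several MoE blocks in this release gain the same sequence-parallel path: each tensor-parallel rank keeps only its (padded) chunk of the token dimension, runs the experts on that chunk, and an all-gather plus a trim back to the original token count restores the full activation. The single-process sketch below is not part of the diff and uses a made-up helper and a stand-in for the expert computation; it only simulates the chunk/gather/trim bookkeeping without real collectives.

# Illustrative single-process simulation; helper name and expert stand-in are hypothetical.
import torch

def simulate_sequence_parallel(hidden_states: torch.Tensor, world_size: int):
    num_tokens, hidden = hidden_states.shape
    pad = (-num_tokens) % world_size
    padded = torch.cat(
        [hidden_states, hidden_states.new_zeros(pad, hidden)], dim=0)
    chunks = padded.chunk(world_size, dim=0)       # one chunk per rank

    processed = [c * 2.0 for c in chunks]          # stand-in for the experts

    gathered = torch.cat(processed, dim=0)         # all-gather along dim 0
    return gathered[:num_tokens]                   # drop the padding again

x = torch.randn(5, 4)
out = simulate_sequence_parallel(x, world_size=4)
assert out.shape == x.shape and torch.allclose(out, x * 2.0)
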
@ -275,15 +290,13 @@ class Qwen3MoeAttention(nn.Module):
|
||||
|
||||
class Qwen3MoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3MoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
@ -315,10 +328,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
if (layer_idx not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and
|
||||
(layer_idx + 1) % config.decoder_sparse_step == 0):
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb)
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
else:
|
||||
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
@ -361,11 +372,9 @@ class Qwen3MoeModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config.get_text_config()
|
||||
cache_config = vllm_config.cache_config
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
enable_eplb = parallel_config.enable_eplb
|
||||
eplb_config = parallel_config.eplb_config
|
||||
self.num_redundant_experts = eplb_config.num_redundant_experts
|
||||
|
||||
@ -379,11 +388,8 @@ class Qwen3MoeModel(nn.Module):
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Qwen3MoeDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb),
|
||||
lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config,
|
||||
prefix=prefix),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -580,7 +586,7 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
config = vllm_config.model_config.hf_text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
|
||||
VllmConfig, get_current_vllm_config)
|
||||
from vllm.distributed import (divide, get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
|
||||
from vllm.model_executor.models.utils import sequence_parallel_chunk
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen3NextConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts = config.num_experts
|
||||
|
||||
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
|
||||
|
||||
if self.tp_size > config.num_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
eplb_config = vllm_config.parallel_config.eplb_config
|
||||
self.enable_eplb = enable_eplb
|
||||
self.enable_eplb = parallel_config.enable_eplb
|
||||
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_redundant_experts = eplb_config.num_redundant_experts
|
||||
@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.experts",
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
num_redundant_experts=self.n_redundant_experts,
|
||||
is_sequence_parallel=self.is_sequence_parallel)
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.num_experts,
|
||||
@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: hidden_states can have either 1D or 2D shape.
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_dim = hidden_states.shape[-1]
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
hidden_states = sequence_parallel_chunk(hidden_states)
|
||||
|
||||
shared_output = None
|
||||
if self.shared_expert is not None:
|
||||
shared_output = self.shared_expert(hidden_states)
|
||||
@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
|
||||
if self.is_sequence_parallel:
|
||||
final_hidden_states = tensor_model_parallel_all_gather(
|
||||
final_hidden_states, 0)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
elif self.tp_size > 1:
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
|
||||
final_hidden_states)
|
||||
|
||||
@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen3NextConfig,
        vllm_config: VllmConfig,
        layer_type: str,
        model_config: Optional[ModelConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ) -> None:
        super().__init__()
        self.config = config

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        speculative_config = vllm_config.speculative_config

        self.layer_type = layer_type
        self.layer_idx = extract_layer_index(prefix)

@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
                config.num_experts > 0 and
                (self.layer_idx + 1) % config.decoder_sparse_step == 0):
            self.mlp = Qwen3NextSparseMoeBlock(
                config=config,
                quant_config=quant_config,
                vllm_config=vllm_config,
                prefix=f"{prefix}.mlp",
                enable_eplb=enable_eplb,
            )
        else:
            self.mlp = Qwen3NextMLP(

@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
            torch.zeros(
                1,
                1,
                self.config.hidden_size,
                config.hidden_size,
                dtype=config.torch_dtype,
            ), )
        self.ffn_layer_scale = torch.nn.Parameter(
            torch.zeros(
                1,
                1,
                self.config.hidden_size,
                config.hidden_size,
                dtype=config.torch_dtype,
            ), )

@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
        super().__init__()

        config: Qwen3NextConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        lora_config = vllm_config.lora_config
        speculative_config = vllm_config.speculative_config
        enable_eplb = parallel_config.enable_eplb
        eplb_config = parallel_config.eplb_config
        self.num_redundant_experts = eplb_config.num_redundant_experts

@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):

        def get_layer(prefix: str):
            return Qwen3NextDecoderLayer(
                config,
                vllm_config,
                layer_type=config.layer_types[extract_layer_index(prefix)],
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                speculative_config=speculative_config,
                prefix=prefix,
                enable_eplb=enable_eplb,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(

@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
        super().__init__()

        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        config: Qwen3NextConfig = model_config.hf_config

@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):

        self.layers = torch.nn.ModuleList(
            Qwen3NextDecoderLayer(
                config,
                vllm_config,
                layer_type="full_attention",
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f'{prefix}.layers.{idx}',
            ) for idx in range(self.num_mtp_layers))
@@ -33,11 +33,14 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLImageProcessorFast
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
    smart_resize as image_smart_resize)
from transformers.models.qwen3_vl import (Qwen3VLProcessor,
                                          Qwen3VLVideoProcessor)
from transformers.models.qwen3_vl.configuration_qwen3_vl import (
    Qwen3VLConfig, Qwen3VLVisionConfig)
from transformers.models.qwen3_vl.video_processing_qwen3_vl import (
    smart_resize as video_smart_resize)
from transformers.video_utils import VideoMetadata

from vllm.attention.layer import check_upstream_fa_availability

@@ -84,6 +87,9 @@ from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model

logger = init_logger(__name__)

# Official recommended max pixels is 24576 * 32 * 32
_MAX_FRAMES_PER_VIDEO = 24576


class Qwen3_VisionPatchEmbed(nn.Module):

@@ -158,6 +164,8 @@ class Qwen3_VisionBlock(nn.Module):
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        if norm_layer is None:

@@ -170,7 +178,9 @@ class Qwen3_VisionBlock(nn.Module):
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_data_parallel=use_data_parallel)
            use_data_parallel=use_data_parallel,
            attn_backend=attn_backend,
            use_upstream_fa=use_upstream_fa)
        self.mlp = Qwen3_VisionMLP(dim,
                                   mlp_hidden_dim,
                                   act_fn=act_fn,

@@ -287,19 +297,6 @@ class Qwen3_VisionTransformer(nn.Module):
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([
            Qwen3_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel)
            for layer_idx in range(vision_config.depth)
        ])

        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,

@@ -325,10 +322,34 @@ class Qwen3_VisionTransformer(nn.Module):

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype())
        use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
                check_upstream_fa_availability(
                    torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                _Backend.ROCM_AITER_FA
        }:
            raise RuntimeError(
                f"Qwen3-VL does not support {self.attn_backend} backend now.")

        self.blocks = nn.ModuleList([
            Qwen3_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
                use_data_parallel=use_data_parallel,
                attn_backend=self.attn_backend,
                use_upstream_fa=use_upstream_fa)
            for layer_idx in range(vision_config.depth)
        ])

    @property
    def dtype(self) -> torch.dtype:
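
The backend-selection hunk above first resolves a ViT attention backend, upgrades to FLASH_ATTN when the upstream flash-attn package is usable, and then rejects anything outside the supported set. A rough sketch of that decision, using a local Backend enum as a stand-in for vLLM's _Backend (names here are illustrative only):

from enum import Enum, auto

class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    XFORMERS = auto()
    ROCM_AITER_FA = auto()
    PALLAS = auto()

SUPPORTED = {Backend.FLASH_ATTN, Backend.TORCH_SDPA, Backend.XFORMERS,
             Backend.ROCM_AITER_FA}

def resolve_vit_backend(detected: Backend, upstream_fa_available: bool) -> Backend:
    # Prefer flash-attn when available, otherwise validate the detected backend.
    if detected is not Backend.FLASH_ATTN and upstream_fa_available:
        return Backend.FLASH_ATTN
    if detected not in SUPPORTED:
        raise RuntimeError(f"Qwen3-VL does not support {detected} backend now.")
    return detected
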
@@ -569,11 +590,16 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
        image_height: int,
        num_frames: int = 2,
        do_resize: bool = True,
        image_processor: Optional[Qwen2VLImageProcessorFast],
        image_processor: Optional[Union[Qwen2VLImageProcessorFast,
                                        Qwen3VLVideoProcessor]],
    ) -> tuple[ImageSize, int]:
        if image_processor is None:
        if image_processor is None and num_frames > 1:
            image_processor = self.get_video_processor()
        elif image_processor is None:
            image_processor = self.get_image_processor()

        is_video = isinstance(image_processor, Qwen3VLVideoProcessor)

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size

@@ -581,12 +607,22 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
        temporal_patch_size = vision_config.temporal_patch_size

        if do_resize:
            if is_video:
                smart_resize = video_smart_resize
                extra_kwargs = {
                    "num_frames": num_frames,
                    "temporal_factor": temporal_patch_size
                }
            else:
                smart_resize = image_smart_resize
                extra_kwargs = {}
            resized_height, resized_width = smart_resize(
                height=image_height,
                width=image_width,
                factor=patch_size * merge_size,
                min_pixels=image_processor.size["shortest_edge"],
                max_pixels=image_processor.size["longest_edge"],
                **extra_kwargs,
            )
            preprocessed_size = ImageSize(width=resized_width,
                                          height=resized_height)
@@ -605,6 +641,39 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):

        return preprocessed_size, num_vision_tokens

    def _get_max_video_frames(self,
                              max_tokens: int,
                              start_num_frames: int = 2) -> int:
        return super()._get_max_video_frames(max_tokens,
                                             start_num_frames=start_num_frames)

    def get_num_frames_with_most_features(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        return super().get_num_frames_with_most_features(
            seq_len, mm_counts, max_frames_per_video=_MAX_FRAMES_PER_VIDEO)

    def get_max_video_tokens(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
        video_soft_tokens = self.get_num_video_tokens(
            image_width=target_width,
            image_height=target_height,
            num_frames=self.get_num_frames_with_most_features(
                seq_len, mm_counts),
            image_processor=None,
        )

        # NOTE: By default in Qwen3-VL, one video token is converted to
        # "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token  # noqa: E501
        formatted_video_soft_tokens = video_soft_tokens * 12.5
        return int(formatted_video_soft_tokens)

    def _calculate_timestamps(self, indices: list[int] | torch.Tensor,
                              video_fps: float, merge_size: int):
        if not isinstance(indices, list):
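
The 12.5 multiplier in get_max_video_tokens above follows from the comment in that hunk: each video token is rendered with a "<{timestamp} seconds>" prefix of roughly 9.5 tokens plus the vision_start, video, and vision_end tokens. A quick check of that arithmetic (the helper name is illustrative, the breakdown is the comment's):

def formatted_video_tokens(video_soft_tokens: int) -> int:
    timestamp_tokens = 9.5  # average length of "<{timestamp} seconds>"
    wrapper_tokens = 3      # vision_start_token + video_token + vision_end_token
    return int(video_soft_tokens * (timestamp_tokens + wrapper_tokens))  # x 12.5

assert formatted_video_tokens(1000) == 12500
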
@@ -674,6 +743,12 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
            self.info.get_image_size_with_most_features())
        target_num_frames = self.info.get_num_frames_with_most_features(
            seq_len, mm_counts)
        target_video_size, _ = self.info._get_vision_info(
            image_width=target_width,
            image_height=target_height,
            num_frames=target_num_frames,
            image_processor=self.info.get_video_processor(),
        )
        return {
            "image":
            self._get_dummy_images(width=target_width,

@@ -681,8 +756,8 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
                                   num_images=num_images),
            "video":
            self._get_dummy_videos(
                width=target_width,
                height=target_height,
                width=target_video_size.width,
                height=target_video_size.height,
                num_frames=target_num_frames,
                num_videos=num_videos,
            ),

@@ -1051,14 +1126,17 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )
        if not multimodal_config.get_limit_per_prompt("image") and \
            not multimodal_config.get_limit_per_prompt("video"):
            self.visual = None
        else:
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )

        self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config,
                                                  prefix=maybe_prefix(

@@ -1074,11 +1152,15 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            config.vision_config.deepstack_visual_indexes
        ) if self.use_deepstack else 0
        # register buffer for deepstack
        self.deepstack_input_embeds = [
            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size)
            for _ in range(self.deepstack_num_level)
        ] if self.use_deepstack else None
        if self.use_deepstack and self.visual is not None:
            self.deepstack_input_embeds = [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.hidden_size)
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

@@ -1513,7 +1595,11 @@ class Qwen3VLForConditionalGeneration(nn.Module, SupportsMultiModal,

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)

        skip_prefixes = []
        if self.visual is None:
            skip_prefixes.extend(["visual."])
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
@@ -212,6 +212,8 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                    # attempted to load as other weights later
                    is_expert_weight = True
                    name_mapped = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    if is_fused_expert:
                        loaded_weight = loaded_weight.transpose(-1,
                                                                -2)  # no bias

@@ -230,8 +232,6 @@ class Qwen3MoeLLMModel(Qwen3MoeModel):
                                name_mapped, params_dict, loaded_weight,
                                shard_id, num_experts)
                    else:
                        if is_pp_missing_parameter(name_mapped, self):
                            continue
                        # Skip loading extra parameters for GPTQ/modelopt models
                        if name_mapped.endswith(
                                ignore_suffixes

@@ -319,13 +319,17 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
        self.multimodal_config = multimodal_config
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.visual = Qwen3_VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
            use_data_parallel=self.use_data_parallel,
        )
        if not multimodal_config.get_limit_per_prompt("image") and \
            not multimodal_config.get_limit_per_prompt("video"):
            self.visual = None
        else:
            self.visual = Qwen3_VisionTransformer(
                config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
                use_data_parallel=self.use_data_parallel,
            )

        self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config,
                                                     prefix=maybe_prefix(

@@ -341,10 +345,14 @@ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
            config.vision_config.deepstack_visual_indexes
        ) if self.use_deepstack else 0
        # register buffer for deepstack
        self.deepstack_input_embeds = [
            torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens,
                        config.text_config.hidden_size)
            for _ in range(self.deepstack_num_level)
        ] if self.use_deepstack else None
        if self.use_deepstack and self.visual is not None:
            self.deepstack_input_embeds = [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    config.text_config.hidden_size)
                for _ in range(self.deepstack_num_level)
            ]
        else:
            self.deepstack_input_embeds = None
        self.visual_dim = config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level
@@ -70,6 +70,7 @@ _TEXT_GENERATION_MODELS = {
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
    "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
    "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
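
The new registry entry above maps the DeepseekV32ForCausalLM architecture onto the existing DeepseekV3ForCausalLM class in deepseek_v2.py, so both architectures share one implementation. A rough sketch of how such a (module, class) table can be resolved; the helper and the package path are illustrative, not vLLM's actual registry code:

import importlib

_TEXT_GENERATION_MODELS = {
    "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
    "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
}

def resolve_model_cls(architecture: str,
                      package: str = "vllm.model_executor.models"):
    module_name, class_name = _TEXT_GENERATION_MODELS[architecture]
    module = importlib.import_module(f"{package}.{module_name}")
    return getattr(module, class_name)
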
@@ -13,11 +13,14 @@ from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
                        get_cuda_view_from_cpu_tensor, is_pin_memory_available,
                        is_uva_available)

logger = init_logger(__name__)

@@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
        return hf_config.hidden_size
    text_config = hf_config.get_text_config()
    return text_config.hidden_size


# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    return torch.ops.vllm.sequence_parallel_chunk_impl(x)


def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        pad_len = tp_size - remainder
        y = nn.functional.pad(x, (0, 0, 0, pad_len))
    else:
        y = x

    chunk = y.shape[0] // tp_size
    start = tp_rank * chunk
    return torch.narrow(y, 0, start, chunk)


def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    tp_size = get_tensor_model_parallel_world_size()
    seq_len = cdiv(x.size(0), tp_size)
    shape = list(x.shape)
    shape[0] = seq_len
    out = torch.empty(shape, dtype=x.dtype, device=x.device)
    return out


direct_register_custom_op(
    op_name="sequence_parallel_chunk_impl",
    op_func=sequence_parallel_chunk_impl,
    fake_impl=sequence_parallel_chunk_impl_fake,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
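
The new sequence_parallel_chunk helper above pads the token axis so it divides evenly by the tensor-parallel world size and then narrows out this rank's slice; the fake implementation only has to report the resulting shape, which is why cdiv joins the vllm.utils import. A self-contained sketch of the same pad-then-narrow math, with the rank and world size passed in explicitly so it runs without a process group (chunk_for_rank is an illustrative name, not a vLLM API):

import torch
import torch.nn as nn

def cdiv(a: int, b: int) -> int:
    # Ceiling division, same result as the cdiv helper used by the fake impl.
    return -(a // -b)

def chunk_for_rank(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    seq_len = x.size(0)
    remainder = seq_len % tp_size
    if remainder != 0:
        # Pad the token axis (dim 0 of a [seq, hidden] tensor) up to a multiple
        # of tp_size so every rank gets an equal chunk.
        x = nn.functional.pad(x, (0, 0, 0, tp_size - remainder))
    chunk = x.shape[0] // tp_size
    return torch.narrow(x, 0, tp_rank * chunk, chunk)

x = torch.randn(10, 4)
chunks = [chunk_for_rank(x, r, 4) for r in range(4)]
assert all(c.shape[0] == cdiv(10, 4) for c in chunks)  # every rank sees 3 rows
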
@@ -50,6 +50,7 @@ class MediaConnector:
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
    ) -> None:
        """
        Args:

@@ -82,6 +83,9 @@ class MediaConnector:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_
        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,

@@ -115,6 +119,14 @@ class MediaConnector:

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self, url_spec) -> None:
        if self.allowed_media_domains and url_spec.hostname not in \
                self.allowed_media_domains:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}")

    def load_from_url(
        self,
        url: str,

@@ -125,8 +137,14 @@ class MediaConnector:
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

@@ -150,8 +168,14 @@ class MediaConnector:
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool,
                                          media_io.load_bytes, data)
            return await future
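
The allowed_media_domains check added above rejects http(s) media URLs whose hostname is not on the allow-list, with an empty list meaning no restriction. A standalone sketch of that check (check_allowed_domain is an illustrative name; the error message mirrors the one in the hunk):

from urllib.parse import urlparse

def check_allowed_domain(url: str, allowed_media_domains: list[str]) -> None:
    url_spec = urlparse(url)
    if allowed_media_domains and url_spec.hostname not in allowed_media_domains:
        raise ValueError(
            f"The URL must be from one of the allowed domains: "
            f"{allowed_media_domains}. Input URL domain: {url_spec.hostname}")

check_allowed_domain("https://example.com/cat.png", ["example.com"])  # passes
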
@@ -93,11 +93,14 @@ class CpuPlatform(Platform):
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
                             has_sink: bool, use_sparse: bool) -> str:
        if selected_backend and selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        if use_mla:
            raise NotImplementedError("MLA is not supported on CPU.")
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on CPU.")
        logger.info("Using Torch SDPA backend.")
        if not use_v1:
            raise ValueError("CPU backend only supports V1.")
@@ -129,6 +129,8 @@ class CudaPlatformBase(Platform):
        # TODO(lucas): handle this more gracefully
        # Note: model_config may be None during testing
        if model_config is not None and model_config.use_mla:
            use_sparse = hasattr(vllm_config.model_config.hf_config,
                                 "index_topk")
            # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA,
            # then we default to FlashMLA backend for non-blackwell GPUs,
            # else we default to CutlassMLA. For each case, we force the

@@ -175,6 +177,12 @@ class CudaPlatformBase(Platform):
                    "Forcing kv cache block size to 64 for FlashInferMLA "
                    "backend.")

            # TODO(Chen): remove this hacky code
            if use_sparse and cache_config.block_size != 64:
                cache_config.block_size = 64
                logger.info(
                    "Forcing kv cache block size to 64 for FlashMLASparse "
                    "backend.")
        # lazy import to avoid circular import
        from vllm.config import CUDAGraphMode

@@ -205,6 +213,12 @@ class CudaPlatformBase(Platform):
    @classmethod
    def get_vit_attn_backend(cls, head_size: int,
                             dtype: torch.dtype) -> _Backend:

        # For Blackwell GPUs, force TORCH_SDPA for now.
        # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
        if cls.has_device_capability(100):
            return _Backend.TORCH_SDPA

        if dtype not in (torch.float16, torch.bfloat16):
            return _Backend.XFORMERS

@@ -225,7 +239,7 @@ class CudaPlatformBase(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
                             has_sink, use_sparse) -> str:
        if use_mla:
            if not use_v1:
                raise RuntimeError(

@@ -235,6 +249,11 @@ class CudaPlatformBase(Platform):
            from vllm.attention.ops.flashmla import is_flashmla_supported
            from vllm.attention.utils.fa_utils import flash_attn_supports_mla

            if use_sparse:
                logger.info_once("Using Sparse MLA backend on V1 engine.")
                return ("vllm.v1.attention.backends.mla.flashmla_sparse."
                        "FlashMLASparseBackend")

            use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                selected_backend is None and cls.is_device_capability(100)
                and block_size == 128)
@@ -194,7 +194,7 @@ class Platform:
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool) -> str:
                             has_sink: bool, use_sparse: bool) -> str:
        """Get the attention backend class of a device."""
        return ""
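
With the use_sparse flag added to the base signature above, every platform override has to accept it; the CPU, ROCm, and TPU hunks in this diff simply reject sparse attention. A minimal sketch of what an out-of-tree platform override might look like under that contract (MyPlatform and the returned backend path are hypothetical, not part of vLLM):

from typing import Optional

import torch

class MyPlatform:
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink: bool, use_sparse: bool) -> str:
        # Platforms without a sparse-attention kernel follow the same pattern
        # as the CPU/ROCm/TPU overrides and bail out early.
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on this platform.")
        return "my_package.attention.backends.MyAttentionBackend"
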
@@ -195,7 +195,10 @@ class RocmPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1, use_mla,
                             has_sink) -> str:
                             has_sink, use_sparse) -> str:
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on ROCm.")
        if use_mla:
            if not use_v1:
                raise RuntimeError(

@@ -49,7 +49,10 @@ class TpuPlatform(Platform):
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
                             has_sink) -> str:
                             has_sink, use_sparse) -> str:
        if use_sparse:
            raise NotImplementedError(
                "Sparse Attention is not supported on TPU.")
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
Some files were not shown because too many files have changed in this diff.