[BugFix] Fix IMA FlashMLA full cuda-graph and DP + Update FlashMLA (#21691)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: yewentao256 <zhyanwentao@126.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Author:  Lucas Wilkinson
Date:    2025-08-08 19:09:42 -04:00 (committed by GitHub)
Parent:  fe6d8257a1
Commit:  cd9b9de1fb
Stat:    3 files changed, 40 insertions(+), 25 deletions(-)
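In short: "IMA" is an illegal memory access. With full CUDA graph capture, the FlashMLA decode path previously captured its tile-scheduler metadata and num-splits buffers lazily from whatever batch was seen first, so a later batch with a different number of SM partitions or sequences (e.g. under data parallelism, the "DP" in the title) could read or write outside the captured storage. This commit pre-allocates both buffers at their upper bounds in the metadata builder's constructor, stages each batch into prefix views of them, and pads the unused tail of num_splits with the last cumulative value rather than zeros. It also bumps the pinned FlashMLA commit to one whose kernels live under csrc/kernels/.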

File: cmake/external_projects/flashmla.cmake

@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
       flashmla
       GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-      GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
+      GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
       GIT_PROGRESS TRUE
       CONFIGURE_COMMAND ""
       BUILD_COMMAND ""
@@ -37,9 +37,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     set(FlashMLA_SOURCES
         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
+        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)

     set(FlashMLA_INCLUDES
         ${flashmla_SOURCE_DIR}/csrc/cutlass/include

File: vllm/attention/ops/flash_mla.py

@@ -91,7 +91,6 @@ def flash_mla_with_kvcache(
     out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
         q,
         k_cache,
-        None,
         head_dim_v,
         cache_seqlens,
         block_table,
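The dropped positional `None` tracks the updated FlashMLA kernel signature. For orientation, a minimal sketch of how this wrapper is typically driven end to end, following the FlashMLA README's decode example; the shapes and values below are illustrative assumptions, not code from this commit:

import torch

from vllm.attention.ops.flash_mla import (flash_mla_with_kvcache,
                                          get_mla_metadata)

# Illustrative MLA decode shapes: 4 sequences, 1 query token each,
# 64 query heads, 1 latent KV head, 576 = 512 latent + 64 RoPE dims.
batch, s_q, h_q, h_kv = 4, 1, 64, 1
d, dv, block_size = 576, 512, 64

q = torch.randn(batch, s_q, h_q, d, device="cuda", dtype=torch.bfloat16)
k_cache = torch.randn(batch * 8, block_size, h_kv, d,
                      device="cuda", dtype=torch.bfloat16)
block_table = torch.arange(batch * 8, device="cuda",
                           dtype=torch.int32).view(batch, 8)
cache_seqlens = torch.full((batch,), 512, device="cuda", dtype=torch.int32)

# Split-KV scheduling metadata; this is exactly what the backend below
# stages into static buffers when full CUDA graphs are enabled.
tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens, s_q * h_q // h_kv, h_kv)

out, softmax_lse = flash_mla_with_kvcache(
    q, k_cache, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=True)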

File: vllm/v1/attention/backends/mla/flashmla.py

@@ -70,6 +70,22 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
+
+        device_properties = torch.cuda.get_device_properties(self.device)
+        num_sms = device_properties.multi_processor_count
+        if self.compilation_config.full_cuda_graph:
+            self.cg_buf_tile_scheduler_metadata = torch.zeros(
+                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
+                # TileSchedulerMetaDataSize = 8
+                (num_sms, 8),
+                device=self.device,
+                dtype=torch.int32,
+            )
+
+            self.cg_buf_num_splits = torch.empty(
+                (vllm_config.scheduler_config.max_num_seqs + 1),
+                device=self.device,
+                dtype=torch.int32)

     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
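The constructor change above is half of the fix: both scratch buffers are now allocated once, at their worst-case sizes (at most one 8-int32 metadata row per SM, and max_num_seqs + 1 split counters), before any CUDA graph is captured. The hunk below supplies the other half, copying each batch's metadata into prefix views of those static buffers. A standalone sketch of that staging pattern, with illustrative names rather than vLLM code:

import torch

device = torch.device("cuda")
MAX_ROWS = torch.cuda.get_device_properties(device).multi_processor_count

# Allocated once, before graph capture, at its upper-bound size.
static_buf = torch.zeros((MAX_ROWS, 8), device=device, dtype=torch.int32)

def stage(fresh: torch.Tensor) -> torch.Tensor:
    """Copy a variable-size result into the static buffer and return a
    prefix view, so a captured graph always reads the same storage no
    matter how many rows this step produced."""
    rows = fresh.size(0)
    assert rows <= MAX_ROWS, "static buffer sized too small"
    view = static_buf[:rows]
    view.copy_(fresh)
    return view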
@@ -80,28 +80,28 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         )

         if self.compilation_config.full_cuda_graph:
-            # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-                self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
+            assert self.cg_buf_tile_scheduler_metadata is not None
+            assert self.cg_buf_num_splits is not None

-                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
-                        tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
-                    copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+            sm_parts = tile_scheduler_metadata.size(0)
+            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
+            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
+            tile_scheduler_metadata_view = \
+                self.cg_buf_tile_scheduler_metadata[:sm_parts]
+            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
+            tile_scheduler_metadata = tile_scheduler_metadata_view

-                # Num splits is per-batch, varying size (batch_size,)
-                n = num_splits.size(0)
-                # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
-                num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
-                num_splits = num_splits_view
+            # Num splits is per-batch, varying size (batch_size,)
+            n = num_splits.size(0)
+            # make sure the static buffer is large enough
+            assert n <= self.cg_buf_num_splits.size(0)
+            num_splits_view = self.cg_buf_num_splits[:n]
+            num_splits_view.copy_(num_splits)
+            # num_splits needs to be monotonically increasing
+            # (with https://github.com/vllm-project/FlashMLA/pull/3; otherwise
+            #  it needs to be monotonically increasing by exactly 1)
+            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
+            num_splits = num_splits_view

         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,
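The one-line semantic change buried in that hunk is the padding value: the old code zero-filled the unused tail of cg_buf_num_splits, but num_splits is a cumulative array that the kernel expects to be monotonically increasing, so a zero tail made it decrease and could steer the kernel out of bounds (the IMA in the title). A toy illustration with made-up values:

import torch

# Live batch of 2 sequences -> num_splits has batch_size + 1 = 3 entries
# (cumulative split offsets); the static buffer was sized for a larger batch.
num_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
buf = torch.empty(5, dtype=torch.int32)

n = num_splits.size(0)
buf[:n].copy_(num_splits)

buf[n:].fill_(0)                # old: [0, 3, 5, 0, 0] -- drops back to 0,
                                # breaking monotonicity past the live batch
buf[n:].fill_(num_splits[-1])   # new: [0, 3, 5, 5, 5] -- non-decreasing,
                                # accepted per vllm-project/FlashMLA#3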