[ROCm] Disable chunked prefill/prefix caching when running MLA on non-cuda platforms (#13844)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Author: Sage Moore
Date: 2025-02-25 22:56:58 -08:00
Committed by: GitHub
Parent: e656f638de
Commit: 1d35662e6d
2 changed files with 44 additions and 12 deletions
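
In short, the change is two-fold: the MLA prefill path now raises NotImplementedError if it is asked to attend over cached context on a non-CUDA platform, and VllmConfig force-disables chunked prefill and prefix caching at startup whenever an MLA model runs off CUDA, so that code path is not reached in normal operation. Below is a minimal standalone sketch of the config-time fallback, using hypothetical names rather than vLLM's actual classes:

from dataclasses import dataclass


@dataclass
class EngineOpts:
    """Hypothetical stand-in for the relevant scheduler/cache settings."""
    enable_chunked_prefill: bool = True
    enable_prefix_caching: bool = True


def apply_mla_non_cuda_fallback(opts: EngineOpts, *, use_mla: bool,
                                is_cuda: bool) -> EngineOpts:
    # Mirrors the intent of the VllmConfig change below: MLA off CUDA
    # cannot use chunked prefill or prefix caching, so both are forced off.
    if use_mla and not is_cuda:
        opts.enable_chunked_prefill = False
        opts.enable_prefix_caching = False
    return opts


# e.g. an MLA model on ROCm: both features end up disabled.
opts = apply_mla_non_cuda_fallback(EngineOpts(), use_mla=True, is_cuda=False)
assert not opts.enable_chunked_prefill and not opts.enable_prefix_caching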


@@ -232,6 +232,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 
 try:
@@ -1371,18 +1372,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
-        output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=prefill_metadata.query_start_loc,
-            cu_seqlens_k=prefill_metadata.query_start_loc,
-            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-            return_softmax_lse=has_context,
-        )
+        if has_context:
+            if not current_platform.is_cuda():
+                raise NotImplementedError(
+                    "Chunked Prefill for MLA is not currently supported on "
+                    "non-cuda platforms")
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+                return_softmax_lse=True,
+            )
+        else:
+            output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.query_start_loc,
+                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
 
         if has_context:
             suffix_output, suffix_lse = output
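
For background on why the has_context branch asks for return_softmax_lse=True: with chunked prefill or prefix caching, part of the prompt is already in the KV cache, so attention over the new tokens and attention over the cached context are computed separately and then combined, and that combination needs each part's log-sum-exp. A minimal sketch of such a merge (generic shapes, not vLLM's actual helper):

import torch


def merge_attn_outputs(out_a: torch.Tensor, lse_a: torch.Tensor,
                       out_b: torch.Tensor, lse_b: torch.Tensor):
    """Combine two partial attention results over disjoint key sets.

    Assumed shapes for this sketch:
      out_*: [num_tokens, num_heads, head_dim]
      lse_*: [num_tokens, num_heads]  (log-sum-exp of the raw scores)
    """
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse).unsqueeze(-1)  # stable softmax weights
    w_b = torch.exp(lse_b - max_lse).unsqueeze(-1)
    merged = (w_a * out_a + w_b * out_b) / (w_a + w_b)
    merged_lse = max_lse + torch.log((w_a + w_b).squeeze(-1))
    return merged, merged_lse

Without the LSE from flash_attn_varlen_func, the suffix output could not be renormalized against the cached-context output; that merge path is the piece not yet wired up off CUDA, hence the NotImplementedError above.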


@@ -3422,6 +3422,20 @@ class VllmConfig:
                    "Disabling `torch.compile`.")
            self.compilation_config.level = CompilationLevel.NO_COMPILATION
 
+        if self.model_config and self.model_config.use_mla and \
+            not current_platform.is_cuda():
+            logger.info(
+                "MLA is enabled on a non-cuda platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            self.scheduler_config.enable_chunked_prefill = False
+            self.scheduler_config.chunked_prefill_enabled = False
+            self.scheduler_config.max_num_batched_tokens = max(
+                self.scheduler_config.max_model_len,
+                _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
+            if self.cache_config is not None:
+                self.cache_config.enable_prefix_caching = False
+
         current_platform.check_and_update_config(self)
 
         if not self.instance_id:
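
One consequence of turning chunked prefill off is that an entire prompt must be scheduled in a single step, which is why max_num_batched_tokens is raised to at least max_model_len above. A tiny worked example with illustrative numbers (the placeholder default below is an assumption, not vLLM's actual value):

_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048   # assumed placeholder value
max_model_len = 32768                    # e.g. a long-context MLA model

# With chunked prefill disabled, the per-step token budget must cover
# the longest possible prompt in one go.
max_num_batched_tokens = max(max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
assert max_num_batched_tokens == 32768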