mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[ROCm] Disable chunked prefill/prefix caching when running MLA on non-cuda platforms (#13844)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
@ -232,6 +232,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
|
||||
|
||||
try:
|
||||
@ -1371,18 +1372,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=prefill_metadata.query_start_loc,
|
||||
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax_lse=has_context,
|
||||
)
|
||||
if has_context:
|
||||
if not current_platform.is_cuda():
|
||||
raise NotImplementedError(
|
||||
"Chunked Prefill for MLA is not currently supported on"
|
||||
"non-cuda platforms")
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=prefill_metadata.query_start_loc,
|
||||
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
else:
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=prefill_metadata.query_start_loc,
|
||||
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if has_context:
|
||||
suffix_output, suffix_lse = output
|
||||
|
@ -3422,6 +3422,20 @@ class VllmConfig:
|
||||
"Disabling `torch.compile`.")
|
||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
||||
|
||||
if self.model_config and self.model_config.use_mla and \
|
||||
not current_platform.is_cuda():
|
||||
logger.info(
|
||||
"MLA is enabled on a non-cuda platform; forcing chunked "
|
||||
"prefill and prefix caching to be disabled.")
|
||||
self.scheduler_config.enable_chunked_prefill = False
|
||||
self.scheduler_config.chunked_prefill_enabled = False
|
||||
self.scheduler_config.max_num_batched_tokens = max(
|
||||
self.scheduler_config.max_model_len,
|
||||
_DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.enable_prefix_caching = False
|
||||
|
||||
current_platform.check_and_update_config(self)
|
||||
|
||||
if not self.instance_id:
|
||||
|
Reference in New Issue
Block a user