[BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067)

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
TY-AMD authored on 2025-07-01 16:12:19 +08:00, committed by GitHub
parent b1c1fe35a5
commit 96453cfa83
5 changed files with 78 additions and 10 deletions
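
For context, the mismatch is easy to observe from the attention selector itself. Below is a minimal sketch (not part of this commit) that mirrors the selector call used in the updated tests further down; it assumes a ROCm build of vLLM with the V1 engine enabled, and the positional arguments are copied from the test diff rather than from a documented signature.

# Minimal reproduction sketch (assumes a ROCm build of vLLM).
import os

# Select the Triton MLA backend and the V1 engine before importing vLLM.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"
os.environ["VLLM_USE_V1"] = "1"

import torch

from vllm.attention.selector import get_attn_backend

# Arguments copied from the tests below (head size 576, bfloat16, "auto"
# KV-cache dtype, block size 16, ..., use_mla=True).
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True)

# Before this fix the V1 engine still resolved to the V0 backend ("TRITON_MLA");
# with the fix it reports the V1 variant.
print(backend.get_name())  # expected: "TRITON_MLA_VLLM_V1"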

View File

@@ -106,10 +106,8 @@ def test_env(
                                                 block_size,
                                                 False,
                                                 use_mla=use_mla)
-                    if use_v1 and name != "TRITON_MLA":
-                        assert backend.get_name() == f"{name}_VLLM_V1"
-                    else:
-                        assert backend.get_name() == name
+                    expected = f"{name}_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected
                 else:
                     with pytest.raises(ValueError) as exc_info:
                         get_attn_backend(16,

View File

@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                                    False, True)
-        assert backend.get_name() == "TRITON_MLA"
+        assert (backend.get_name() == "TRITON_MLA"
+                or backend.get_name() == "TRITON_MLA_VLLM_V1")

         # If attention backend is None
         # If use_mla is true
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, None)
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                                    False, True)
-        assert backend.get_name() == "TRITON_MLA"
+        assert (backend.get_name() == "TRITON_MLA"
+                or backend.get_name() == "TRITON_MLA_VLLM_V1")

         # change the attention backend to AITER MLA
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")

View File

@@ -186,8 +186,14 @@ class RocmPlatform(Platform):
         if selected_backend == _Backend.TRITON_MLA:
             if block_size != 1:
-                logger.info("Using Triton MLA backend.")
-                return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501
+                if use_v1:
+                    logger.info_once(
+                        "Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501
             else:
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"

View File

@@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
-        self.vllm_flash_attn_version = get_flash_attn_version()

         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             maybe_padded_v = torch.nn.functional.pad(
                 v, [0, q.shape[-1] - v.shape[-1]], value=0)

+        if is_vllm_fa:
+            kwargs["return_softmax_lse"] = return_softmax_lse
+        else:
+            # ROCm leverages the upstream flash_attn, which takes a parameter
+            # called "return_attn_probs" instead of return_softmax_lse
+            kwargs["return_attn_probs"] = return_softmax_lse
+
         attn_out = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=maybe_padded_v,
-            return_softmax_lse=return_softmax_lse,
             softmax_scale=softmax_scale,
             **kwargs,
         )
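
The remapping above is needed because upstream flash_attn (used on ROCm) spells the flag "return_attn_probs", while vllm_flash_attn spells it "return_softmax_lse"; the old call site hard-coded the latter, which the upstream function does not accept. Below is a standalone sketch of the same adapter pattern, where fa_varlen is a hypothetical stand-in for either library's varlen entry point (not a real vLLM symbol).

# Sketch of the adapter pattern used above; `fa_varlen` is hypothetical.
def call_fa_varlen(fa_varlen, is_vllm_fa: bool, want_lse: bool, **kwargs):
    if is_vllm_fa:
        # vllm_flash_attn spelling of the flag
        kwargs["return_softmax_lse"] = want_lse
    else:
        # upstream flash_attn (ROCm) spelling of the same flag
        kwargs["return_attn_probs"] = want_lse
    return fa_varlen(**kwargs)

Keeping the flag dispatch in one place lets the flash_attn_varlen_func call itself stay library-agnostic.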

View File

@@ -5,10 +5,14 @@ from typing import Any, Optional

 import torch

+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
+from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata)
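
The added imports feed the gating introduced below: the ROCm Triton attention path is only taken when Triton is available, VLLM_USE_TRITON_FLASH_ATTN is enabled, and the caller does not need the softmax LSE. A small sketch (not from this commit) that checks the first two knobs up front:

# Sketch only: check the knobs that gate the new ROCm Triton path below.
from vllm import envs
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON

triton_path_possible = (current_platform.is_rocm() and HAS_TRITON
                        and envs.VLLM_USE_TRITON_FLASH_ATTN)
print(f"ROCm Triton flash-attention path available: {triton_path_possible}")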
@@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             raise NotImplementedError(
                 "TritonMLA V1 with FP8 KV cache not yet supported")

+        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
+        self.triton_fa_func = triton_attention if HAS_TRITON else None
+
+    def _flash_attn_varlen_diff_headdims_rocm(self,
+                                              q,
+                                              k,
+                                              v,
+                                              softmax_scale=None,
+                                              **kwargs):
+        assert self.triton_fa_func is not None
+
+        # Triton Attention requires a padded V
+        padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
+                                           value=0)
+
+        # The output of triton_attention is a tuple of
+        # [output_tensor, encoded_softmax] where encoded_softmax is always None
+        output_tensor, _ = self.triton_fa_func(
+            q,
+            k,
+            padded_v,
+            None,  # output
+            kwargs["cu_seqlens_q"],
+            kwargs["cu_seqlens_k"],
+            kwargs["max_seqlen_q"],
+            kwargs["max_seqlen_k"],
+            kwargs["causal"],
+            softmax_scale,
+            None,  # bias
+        )
+        return output_tensor
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        if current_platform.is_rocm() \
+            and self.use_triton_flash_attn \
+            and not return_softmax_lse:
+            return self._flash_attn_varlen_diff_headdims_rocm(
+                q, k, v, softmax_scale=softmax_scale, **kwargs)
+        else:
+            return super()._flash_attn_varlen_diff_headdims(
+                q,
+                k,
+                v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs)
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,