Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[BugFix][V1][ROCm] Triton MLA uses V0 backend on V1 engine (#19067)
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
@@ -106,10 +106,8 @@ def test_env(
                                        block_size,
                                        False,
                                        use_mla=use_mla)
-            if use_v1 and name != "TRITON_MLA":
-                assert backend.get_name() == f"{name}_VLLM_V1"
-            else:
-                assert backend.get_name() == name
+            expected = f"{name}_VLLM_V1" if use_v1 else name
+            assert backend.get_name() == expected
         else:
             with pytest.raises(ValueError) as exc_info:
                 get_attn_backend(16,
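The test no longer special-cases TRITON_MLA: with this fix, every V1 backend is expected to report its V0 name plus a "_VLLM_V1" suffix. A minimal sketch of that naming convention, derived only from the assertion above (the helper name is illustrative, not part of the test):

def expected_backend_name(name: str, use_v1: bool) -> str:
    # V1 attention backends report "<NAME>_VLLM_V1"; V0 backends report "<NAME>".
    return f"{name}_VLLM_V1" if use_v1 else name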
@@ -35,7 +35,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                                    False, True)
-        assert backend.get_name() == "TRITON_MLA"
+        assert (backend.get_name() == "TRITON_MLA"
+                or backend.get_name() == "TRITON_MLA_VLLM_V1")
 
         # If attention backend is None
         # If use_mla is true
@@ -43,7 +44,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, None)
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
                                    False, True)
-        assert backend.get_name() == "TRITON_MLA"
+        assert (backend.get_name() == "TRITON_MLA"
+                or backend.get_name() == "TRITON_MLA_VLLM_V1")
 
         # change the attention backend to AITER MLA
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
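For reference, a self-contained sketch of the check these tests perform. It assumes get_attn_backend is importable from vllm.attention.selector, that STR_BACKEND_ENV_VAR resolves to the "VLLM_ATTENTION_BACKEND" environment variable, and that it runs on a ROCm machine; the call arguments are copied from the test above, everything else is an assumption:

import os

import torch

from vllm.attention.selector import get_attn_backend

# Pin the MLA backend by name, as the monkeypatched tests above do.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, False, True)
# V0 reports "TRITON_MLA"; the V1 engine reports "TRITON_MLA_VLLM_V1".
assert backend.get_name() in ("TRITON_MLA", "TRITON_MLA_VLLM_V1")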
@@ -186,8 +186,14 @@ class RocmPlatform(Platform):
 
         if selected_backend == _Backend.TRITON_MLA:
             if block_size != 1:
-                logger.info("Using Triton MLA backend.")
-                return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501
+                if use_v1:
+                    logger.info_once(
+                        "Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"  # noqa: E501
             else:
                 raise ValueError(
                     f" The selected backend, {selected_backend.name},"
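The platform hook now dispatches on the engine version as well as the backend name. An illustrative sketch of the pattern (not the vLLM source; the function name is invented for clarity):

def pick_triton_mla_path(use_v1: bool) -> str:
    # The same TRITON_MLA selection maps to different implementation classes
    # depending on whether the V1 engine is active.
    if use_v1:
        return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
    return "vllm.attention.backends.triton_mla.TritonMLABackend"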
@@ -640,7 +640,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
-        self.vllm_flash_attn_version = get_flash_attn_version()
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
@@ -672,11 +671,17 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             maybe_padded_v = torch.nn.functional.pad(
                 v, [0, q.shape[-1] - v.shape[-1]], value=0)
 
+        if is_vllm_fa:
+            kwargs["return_softmax_lse"] = return_softmax_lse
+        else:
+            # ROCm leverages the upstream flash_attn, which takes a parameter
+            # called "return_attn_probs" instead of return_softmax_lse
+            kwargs["return_attn_probs"] = return_softmax_lse
+
         attn_out = self.flash_attn_varlen_func(
             q=q,
             k=k,
             v=maybe_padded_v,
-            return_softmax_lse=return_softmax_lse,
             softmax_scale=softmax_scale,
             **kwargs,
         )
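The shim above exists because upstream flash_attn (used on ROCm) names the flag "return_attn_probs", while vllm_flash_attn names it "return_softmax_lse". A minimal sketch of the same decision, with an invented helper name:

def build_lse_kwargs(is_vllm_fa: bool, return_softmax_lse: bool) -> dict:
    # Pass whichever keyword the active flash-attention library expects.
    if is_vllm_fa:
        return {"return_softmax_lse": return_softmax_lse}
    return {"return_attn_probs": return_softmax_lse}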
@@ -5,10 +5,14 @@ from typing import Any, Optional
 
 import torch
 
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
+from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
+from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata)
@@ -68,6 +72,59 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
             raise NotImplementedError(
                 "TritonMLA V1 with FP8 KV cache not yet supported")
 
+        self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
+        self.triton_fa_func = triton_attention if HAS_TRITON else None
+
+    def _flash_attn_varlen_diff_headdims_rocm(self,
+                                              q,
+                                              k,
+                                              v,
+                                              softmax_scale=None,
+                                              **kwargs):
+        assert self.triton_fa_func is not None
+
+        # Triton Attention requires a padded V
+        padded_v = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
+                                           value=0)
+        # The output of triton_attention is a tuple of
+        # [output_tensor, encoded_softmax] where encoded_softmax is always None
+        output_tensor, _ = self.triton_fa_func(
+            q,
+            k,
+            padded_v,
+            None,  # output
+            kwargs["cu_seqlens_q"],
+            kwargs["cu_seqlens_k"],
+            kwargs["max_seqlen_q"],
+            kwargs["max_seqlen_k"],
+            kwargs["causal"],
+            softmax_scale,
+            None,  # bias
+        )
+
+        return output_tensor
+
+    def _flash_attn_varlen_diff_headdims(self,
+                                         q,
+                                         k,
+                                         v,
+                                         return_softmax_lse=False,
+                                         softmax_scale=None,
+                                         **kwargs):
+        if current_platform.is_rocm() \
+            and self.use_triton_flash_attn \
+            and not return_softmax_lse:
+            return self._flash_attn_varlen_diff_headdims_rocm(
+                q, k, v, softmax_scale=softmax_scale, **kwargs)
+        else:
+            return super()._flash_attn_varlen_diff_headdims(
+                q,
+                k,
+                v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs)
+
     def _forward_decode(
         self,
         q_nope: torch.Tensor,
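Note that the new ROCm Triton path is only taken when the caller does not ask for the softmax LSE, since triton_attention returns None for encoded_softmax; otherwise the call falls back to the parent flash_attn varlen path. A hedged restatement of that dispatch condition (helper name invented for illustration):

def use_rocm_triton_path(is_rocm: bool, use_triton_flash_attn: bool,
                         return_softmax_lse: bool) -> bool:
    # Mirrors the condition in _flash_attn_varlen_diff_headdims above.
    return is_rocm and use_triton_flash_attn and not return_softmax_lse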