[V1] Enable Triton(ROCm) Attention backend for Nvidia GPUs (#14071)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Isotr0py
Date: 2025-03-21 11:14:19 +08:00 (committed by GitHub)
Commit: f8a08cb90d (parent: b15fd2be2a)
5 changed files with 23 additions and 16 deletions

File 1 of 5

@@ -1588,7 +1588,7 @@ class EngineArgs:
         # No FlashInfer or XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
-            "TRITON_MLA", "FLASHMLA"
+            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):

File 2 of 5

@@ -213,9 +213,14 @@ class CudaPlatformBase(Platform):
             return ("vllm.attention.backends."
                     "flashmla.FlashMLABackend")
         if use_v1:
-            logger.info_once("Using Flash Attention backend on V1 engine.")
-            return ("vllm.v1.attention.backends.flash_attn."
-                    "FlashAttentionBackend")
+            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+                logger.info_once("Using Triton backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "triton_attn.TritonAttentionBackend")
+            if cls.has_device_capability(80):
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"

File 3 of 5

@@ -29,6 +29,7 @@ def in_wsl() -> bool:
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     FLASH_ATTN_VLLM_V1 = enum.auto()
+    TRITON_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()

File 4 of 5

@@ -120,8 +120,9 @@ class RocmPlatform(Platform):
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if envs.VLLM_USE_V1:
-            logger.info("Using ROCm Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends."
+                    "triton_attn.TritonAttentionBackend")
         if selected_backend == _Backend.ROCM_FLASH:
             if not cls.has_device_capability(90):
                 # not Instinct series GPUs.

File 5 of 5

@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-"""Attention layer with PagedAttention on rocm"""
+"""Attention layer with PagedAttention and Triton prefix prefill."""
 from typing import Any, Optional

 import torch
@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import (

 logger = init_logger(__name__)


-class ROCmAttentionBackend(AttentionBackend):
+class TritonAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True
@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "ROCM_ATTN_VLLM_V1"
+        return "TRITON_ATTN_VLLM_V1"

     @staticmethod
-    def get_impl_cls() -> type["ROCmAttentionImpl"]:
-        return ROCmAttentionImpl
+    def get_impl_cls() -> type["TritonAttentionImpl"]:
+        return TritonAttentionImpl

     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend):
         return FlashAttentionMetadataBuilder


-class ROCmAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(AttentionImpl):

     def __init__(
         self,
@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl):
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
-                "ROCmAttention does not support block-sparse attention.")
+                "TritonAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
+        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by ROCmAttention. "
+                f"Head size {head_size} is not supported by TritonAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
-                                      "ROCmAttentionImpl")
+                                      "TritonAttentionImpl")

     def forward(
         self,
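The platform hooks in the earlier files return the backend as a dotted path string, and the renamed class advertises itself under the name "TRITON_ATTN_VLLM_V1" via get_name(). As a rough illustration of how such a path can be turned back into the class, here is a generic importlib-based resolver; vLLM ships its own resolution utility, so this is only a sketch, and the commented usage assumes vLLM is installed.

# Illustrative only: resolving a dotted backend path like the ones returned by
# the platform code above. importlib is used purely as a sketch.
import importlib


def load_backend(qualname: str):
    """Split a dotted path into module and attribute, then import and fetch it."""
    module_name, _, class_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# With vLLM installed, the renamed backend would resolve like this:
# backend_cls = load_backend(
#     "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend")
# assert backend_cls.get_name() == "TRITON_ATTN_VLLM_V1"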