From f8a08cb90dc0b5b45663cd2605d0c98c77efe009 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 21 Mar 2025 11:14:19 +0800
Subject: [PATCH] [V1] Enable Triton(ROCm) Attention backend for Nvidia GPUs (#14071)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon
---
 vllm/engine/arg_utils.py                          |  2 +-
 vllm/platforms/cuda.py                            | 11 +++++++---
 vllm/platforms/interface.py                       |  1 +
 vllm/platforms/rocm.py                            |  5 +++--
 .../backends/{rocm_attn.py => triton_attn.py}     | 20 +++++++++----------
 5 files changed, 23 insertions(+), 16 deletions(-)
 rename vllm/v1/attention/backends/{rocm_attn.py => triton_attn.py} (91%)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 986d1b4074..edfa748b82 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1588,7 +1588,7 @@ class EngineArgs:
         # No FlashInfer or XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
-            "TRITON_MLA", "FLASHMLA"
+            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index dd2a9cb616..38d8fffd63 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -213,9 +213,14 @@ class CudaPlatformBase(Platform):
                 return ("vllm.attention.backends."
                         "flashmla.FlashMLABackend")
         if use_v1:
-            logger.info_once("Using Flash Attention backend on V1 engine.")
-            return ("vllm.v1.attention.backends.flash_attn."
-                    "FlashAttentionBackend")
+            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+                logger.info_once("Using Triton backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "triton_attn.TritonAttentionBackend")
+            if cls.has_device_capability(80):
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index c7152d0bfb..d3bffaf4d6 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -29,6 +29,7 @@ def in_wsl() -> bool:
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     FLASH_ATTN_VLLM_V1 = enum.auto()
+    TRITON_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 75f287b568..ee708f5961 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -120,8 +120,9 @@ class RocmPlatform(Platform):
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if envs.VLLM_USE_V1:
-            logger.info("Using ROCm Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends."
+                    "triton_attn.TritonAttentionBackend")
         if selected_backend == _Backend.ROCM_FLASH:
             if not cls.has_device_capability(90):
                 # not Instinct series GPUs.
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/triton_attn.py
similarity index 91%
rename from vllm/v1/attention/backends/rocm_attn.py
rename to vllm/v1/attention/backends/triton_attn.py
index 640c3b3d4f..f11f2b6271 100644
--- a/vllm/v1/attention/backends/rocm_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-"""Attention layer with PagedAttention on rocm"""
+"""Attention layer with PagedAttention and Triton prefix prefill."""
 from typing import Any, Optional
 
 import torch
@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import (
 logger = init_logger(__name__)
 
 
-class ROCmAttentionBackend(AttentionBackend):
+class TritonAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
 
@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_name() -> str:
-        return "ROCM_ATTN_VLLM_V1"
+        return "TRITON_ATTN_VLLM_V1"
 
     @staticmethod
-    def get_impl_cls() -> type["ROCmAttentionImpl"]:
-        return ROCmAttentionImpl
+    def get_impl_cls() -> type["TritonAttentionImpl"]:
+        return TritonAttentionImpl
 
     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend):
         return FlashAttentionMetadataBuilder
 
 
-class ROCmAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(AttentionImpl):
 
     def __init__(
         self,
@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl):
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
-                "ROCmAttention does not support block-sparse attention.")
+                "TritonAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
+        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by ROCmAttention. "
+                f"Head size {head_size} is not supported by TritonAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
-                                      "ROCmAttentionImpl")
+                                      "TritonAttentionImpl")
 
     def forward(
         self,
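A minimal usage sketch (not part of the diff itself): one way the renamed backend could be forced on an NVIDIA GPU after this change, assuming the standard `vllm.LLM` entry point. The environment variables are the ones checked in `arg_utils.py` and `rocm.py` above; the model name is an arbitrary placeholder.

# Sketch: force the Triton attention backend on the V1 engine.
# Environment variables must be set before vLLM is imported.
import os

os.environ["VLLM_USE_V1"] = "1"
# Backend name registered by TritonAttentionBackend.get_name() above.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

# Placeholder model chosen only for illustration.
llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, Triton attention!"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)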