[V1] Enable Triton(ROCm) Attention backend for Nvidia GPUs (#14071)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
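
A minimal offline-inference sketch for trying the new backend on an NVIDIA GPU, assuming a vLLM build that contains this commit. Both environment variables already exist in vLLM (they appear in the hunks below), and TRITON_ATTN_VLLM_V1 is the backend name registered by this change; the model name is only a placeholder.

    # Sketch: enable the V1 engine and explicitly request the Triton attention backend.
    # "facebook/opt-125m" is a placeholder; any supported decoder-only model works.
    import os

    os.environ["VLLM_USE_V1"] = "1"
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)
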
@@ -1588,7 +1588,7 @@ class EngineArgs:
         # No FlashInfer or XFormers so far.
         V1_BACKENDS = [
             "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
-            "TRITON_MLA", "FLASHMLA"
+            "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
         ]
         if (envs.is_set("VLLM_ATTENTION_BACKEND")
                 and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
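
The hunk above only adds the new name to the V1 allow-list. A standalone sketch of the same check, reading the environment variable directly instead of going through vllm.envs; the ValueError is illustrative, the real code path reacts to the mismatch differently.

    import os

    # Backend names copied from the diff above.
    V1_BACKENDS = [
        "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
        "TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA",
    ]

    requested = os.environ.get("VLLM_ATTENTION_BACKEND")
    if requested is not None and requested not in V1_BACKENDS:
        raise ValueError(f"{requested} is not a V1-compatible attention backend")
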
@@ -213,9 +213,14 @@ class CudaPlatformBase(Platform):
                 return ("vllm.attention.backends."
                         "flashmla.FlashMLABackend")
         if use_v1:
-            logger.info_once("Using Flash Attention backend on V1 engine.")
-            return ("vllm.v1.attention.backends.flash_attn."
-                    "FlashAttentionBackend")
+            if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
+                logger.info_once("Using Triton backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "triton_attn.TritonAttentionBackend")
+            if cls.has_device_capability(80):
+                logger.info_once("Using Flash Attention backend on V1 engine.")
+                return ("vllm.v1.attention.backends."
+                        "flash_attn.FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
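
On the V1 engine, an explicit TRITON_ATTN_VLLM_V1 request now takes priority, and the FlashAttention default is gated on compute capability 8.0 (Ampere) or newer. A pure-function sketch of that ordering; the strings are the dotted paths from the hunk, and what happens for pre-Ampere GPUs is outside this hunk.

    from typing import Optional

    def pick_v1_cuda_backend(selected_backend: Optional[str], capability: int) -> Optional[str]:
        # Explicit user selection wins.
        if selected_backend == "TRITON_ATTN_VLLM_V1":
            return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
        # Otherwise Ampere-or-newer GPUs keep the FlashAttention default.
        if capability >= 80:
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        # Pre-Ampere GPUs fall through to selection logic not shown in this hunk.
        return None
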
@@ -29,6 +29,7 @@ def in_wsl() -> bool:
 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()
     FLASH_ATTN_VLLM_V1 = enum.auto()
+    TRITON_ATTN_VLLM_V1 = enum.auto()
     XFORMERS = enum.auto()
     ROCM_FLASH = enum.auto()
     TORCH_SDPA = enum.auto()
@@ -120,8 +120,9 @@ class RocmPlatform(Platform):
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if envs.VLLM_USE_V1:
-            logger.info("Using ROCm Attention backend on V1 engine.")
-            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends."
+                    "triton_attn.TritonAttentionBackend")
         if selected_backend == _Backend.ROCM_FLASH:
             if not cls.has_device_capability(90):
                 # not Instinct series GPUs.
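
With this hunk the ROCm platform hook returns the same dotted path as the new CUDA branch above, so both vendors share a single Triton-based V1 backend. The engine resolves such a path roughly as follows; the import logic here is illustrative, not vLLM's actual loader.

    import importlib

    path = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
    module_name, _, class_name = path.rpartition(".")
    backend_cls = getattr(importlib.import_module(module_name), class_name)
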
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-"""Attention layer with PagedAttention on rocm"""
+"""Attention layer with PagedAttention and Triton prefix prefill."""
 from typing import Any, Optional

 import torch
@@ -16,7 +16,7 @@ from vllm.v1.attention.backends.flash_attn import (
 logger = init_logger(__name__)


-class ROCmAttentionBackend(AttentionBackend):
+class TritonAttentionBackend(AttentionBackend):

     accept_output_buffer: bool = True

@@ -26,11 +26,11 @@ class ROCmAttentionBackend(AttentionBackend):

     @staticmethod
     def get_name() -> str:
-        return "ROCM_ATTN_VLLM_V1"
+        return "TRITON_ATTN_VLLM_V1"

     @staticmethod
-    def get_impl_cls() -> type["ROCmAttentionImpl"]:
-        return ROCmAttentionImpl
+    def get_impl_cls() -> type["TritonAttentionImpl"]:
+        return TritonAttentionImpl

     @staticmethod
     def get_metadata_cls() -> type["AttentionMetadata"]:
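
The string returned by get_name() is the same token added to V1_BACKENDS and accepted via VLLM_ATTENTION_BACKEND. A quick consistency check, runnable only against a vLLM build that contains this commit:

    from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend

    assert TritonAttentionBackend.get_name() == "TRITON_ATTN_VLLM_V1"
    assert TritonAttentionBackend.get_impl_cls().__name__ == "TritonAttentionImpl"
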
@@ -56,7 +56,7 @@ class ROCmAttentionBackend(AttentionBackend):
         return FlashAttentionMetadataBuilder


-class ROCmAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(AttentionImpl):

     def __init__(
         self,
@@ -73,7 +73,7 @@ class ROCmAttentionImpl(AttentionImpl):
     ) -> None:
         if blocksparse_params is not None:
             raise ValueError(
-                "ROCmAttention does not support block-sparse attention.")
+                "TritonAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -90,17 +90,17 @@ class ROCmAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
+        support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
-                f"Head size {head_size} is not supported by ROCmAttention. "
+                f"Head size {head_size} is not supported by TritonAttention. "
                 f"Supported head sizes are: {support_head_sizes}.")

         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
-                                      "ROCmAttentionImpl")
+                                      "TritonAttentionImpl")

     def forward(
         self,
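
The assert and integer division in this hunk encode grouped-query attention: the query heads are split evenly across the KV heads. A tiny numeric illustration with arbitrary head counts:

    num_heads, num_kv_heads = 32, 8       # arbitrary example values
    assert num_heads % num_kv_heads == 0
    num_queries_per_kv = num_heads // num_kv_heads
    print(num_queries_per_kv)             # 4 query heads share each KV head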