[Misc] Add get_name method to attention backends (#4685)

Woosuk Kwon authored 2024-05-08 09:59:31 -07:00, committed by GitHub
parent 0f9a6e3d22
commit 5510cf0e8a
7 changed files with 30 additions and 12 deletions


@@ -9,6 +9,11 @@ import torch
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
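
As an aside, here is a minimal, self-contained sketch of the interface this hunk establishes: every backend must now report a stable string name alongside its implementation class. AttentionImpl is a stand-in and DummyBackend is hypothetical, used only to show how a subclass satisfies the new abstract method; neither is part of this commit or of vLLM.

from abc import ABC, abstractmethod
from typing import Type


class AttentionImpl:
    """Stand-in for vLLM's AttentionImpl, so the sketch runs on its own."""


class AttentionBackend(ABC):
    """Abridged copy of the abstract interface as of this commit."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        raise NotImplementedError


class DummyBackend(AttentionBackend):
    """Hypothetical backend showing how the new method is implemented."""

    @staticmethod
    def get_name() -> str:
        return "dummy"

    @staticmethod
    def get_impl_cls() -> Type["AttentionImpl"]:
        return AttentionImpl


print(DummyBackend.get_name())  # -> dummy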


@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl


@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 class FlashInferBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl
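
A quick sanity check of the new name on the FlashInfer backend; the module path is taken from the import that model_runner drops later in this commit. Note that, with the try/except around the imports removed above, importing this module now assumes flashinfer and flash-attn are actually installed.

# Assumes a vLLM checkout at this commit, with flashinfer and flash_attn
# installed (the imports above are now unconditional).
from vllm.attention.backends.flashinfer import FlashInferBackend

assert FlashInferBackend.get_name() == "flashinfer"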


@@ -17,6 +17,10 @@ logger = init_logger(__name__)
 class ROCmFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl


@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
 class TorchSDPABackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl


@@ -20,6 +20,10 @@ logger = init_logger(__name__)
 class XFormersBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
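
For reference, the five names introduced across the backend hunks above, collected into one place; the dictionary below is only a summary aid, not code that exists in the repository.

# Backend class -> string returned by its new get_name() (from this commit).
BACKEND_NAMES = {
    "FlashAttentionBackend": "flash-attn",
    "FlashInferBackend": "flashinfer",
    "ROCmFlashAttentionBackend": "rocm-flash-attn",
    "TorchSDPABackend": "torch-sdpa",
    "XFormersBackend": "xformers",
}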


@@ -9,7 +9,6 @@ import torch.nn as nn
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -395,7 +394,7 @@ class ModelRunner:
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
            attn_metadata = self.attn_backend.make_metadata(
                is_prompt=True,
                use_cuda_graph=False,
@@ -556,7 +555,7 @@ class ModelRunner:
            device=self.device,
        )
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                # Allocate 16MB workspace buffer
                # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
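
The two hunks above replace an identity check against the FlashInferBackend class with a comparison on get_name(), which is what lets model_runner drop its direct import of the flashinfer backend module (whose dependencies are now imported unconditionally). Below is a minimal sketch of that dispatch-by-name pattern; Runner and FakeBackend are illustrative stand-ins, not vLLM classes.

class FakeBackend:
    """Stand-in for an attention backend that only reports its name."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name


class Runner:
    """Stand-in for ModelRunner, showing the name-based branch."""

    def __init__(self, attn_backend: FakeBackend) -> None:
        self.attn_backend = attn_backend

    def build_metadata(self) -> str:
        # Comparing strings avoids importing the concrete backend class just
        # to test identity against it.
        if self.attn_backend.get_name() == "flashinfer":
            return "flashinfer-specific metadata"
        return "generic metadata"


print(Runner(FakeBackend("flashinfer")).build_metadata())  # flashinfer-specific metadata
print(Runner(FakeBackend("xformers")).build_metadata())    # generic metadata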