Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Misc] Add get_name method to attention backends (#4685)
@@ -9,6 +9,11 @@ import torch
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""

+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
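The new contract is small: a concrete backend must now implement get_name() alongside get_impl_cls(). As a minimal sketch (not part of this commit), a hypothetical backend could look like the following; DummyBackend and DummyImpl are made-up names, and the other abstract methods of AttentionBackend are omitted for brevity.

from typing import Type

from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl


class DummyImpl(AttentionImpl):  # hypothetical implementation class
    ...


class DummyBackend(AttentionBackend):  # hypothetical backend

    @staticmethod
    def get_name() -> str:
        return "dummy"

    @staticmethod
    def get_impl_cls() -> Type["DummyImpl"]:
        return DummyImpl


# Static methods are callable on the class itself, no instance required.
assert DummyBackend.get_name() == "dummy"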
@@ -19,6 +19,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,

 class FlashAttentionBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl
@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type

-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper

 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
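One side effect of the import change above: the module now imports flashinfer unconditionally, so importing it fails when the optional package is missing. A minimal sketch, using a hypothetical helper name, of how a caller can probe availability without importing this module:

import importlib.util


def flashinfer_is_available() -> bool:  # hypothetical helper, not in vllm
    # Check for the optional package without importing the backend module.
    return importlib.util.find_spec("flashinfer") is not None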
@@ -20,6 +14,10 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,

 class FlashInferBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl
@@ -17,6 +17,10 @@ logger = init_logger(__name__)

 class ROCmFlashAttentionBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl
@@ -15,6 +15,10 @@ from vllm.attention.ops.paged_attn import (PagedAttention,

 class TorchSDPABackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
@@ -20,6 +20,10 @@ logger = init_logger(__name__)

 class XFormersBackend(AttentionBackend):

+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
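With the hunks above, every in-tree backend reports a stable identifier ("flash-attn", "flashinfer", "rocm-flash-attn", "torch-sdpa", "xformers") through a plain static call. A short usage sketch; the import path is an assumption, since the diff does not show file names:

# The name is available directly on the class, no instance needed.
from vllm.attention.backends.flash_attn import FlashAttentionBackend  # path assumed

assert FlashAttentionBackend.get_name() == "flash-attn"
assert FlashAttentionBackend.get_impl_cls().__name__ == "FlashAttentionImpl"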
@@ -9,7 +9,6 @@ import torch.nn as nn

 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -395,7 +394,7 @@ class ModelRunner:
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])

-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             attn_metadata = self.attn_backend.make_metadata(
                 is_prompt=True,
                 use_cuda_graph=False,
@@ -556,7 +555,7 @@
             device=self.device,
         )

-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             if not hasattr(self, "flashinfer_workspace_buffer"):
                 # Allocate 16MB workspace buffer
                 # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
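The motivation for the last two hunks: the old identity check (self.attn_backend is FlashInferBackend) forced model_runner.py to import the FlashInfer backend class at module load, while the string comparison works for any backend that implements get_name(). A self-contained sketch of the same idea; the stub class and helper are hypothetical, not vllm code:

class _StubBackend:  # hypothetical stand-in for whatever get_attn_backend() returns
    @staticmethod
    def get_name() -> str:
        return "flashinfer"


def needs_flashinfer_metadata(attn_backend) -> bool:  # hypothetical helper
    # Name comparison avoids importing flashinfer-specific classes here.
    return attn_backend.get_name() == "flashinfer"


assert needs_flashinfer_metadata(_StubBackend)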