Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[platform] Allow platform specify attention backend (#11609)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
@@ -1,10 +1,10 @@
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
 from vllm.platforms.openvino import OpenVinoPlatform
@@ -12,6 +12,13 @@ from vllm.platforms.rocm import RocmPlatform
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
+@pytest.fixture(autouse=True)
+def clear_cache():
+    """Clear lru cache to ensure each test case runs without caching.
+    """
+    _cached_get_attn_backend.cache_clear()
+
+
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
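Note: the autouse clear_cache fixture added above is needed because _cached_get_attn_backend is memoized (its name and the cache_clear() call indicate an lru_cache); without clearing it between tests, the backend chosen by the first test would be served from the cache in every later one. A minimal standalone sketch of that caching behaviour (illustrative only, not vllm code):

import functools

calls = {"n": 0}

@functools.lru_cache(maxsize=None)
def pick_backend(platform: str) -> str:
    # Stand-in for _cached_get_attn_backend: computed once per argument set.
    calls["n"] += 1
    return f"backend-for-{platform}"

pick_backend("cpu")
pick_backend("cpu")
assert calls["n"] == 1      # the second call was answered from the cache
pick_backend.cache_clear()  # what the clear_cache fixture does between tests
pick_backend("cpu")
assert calls["n"] == 2      # recomputed after clearing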
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):
 
     if device == "cpu":
         with patch("vllm.attention.selector.current_platform", CpuPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "TORCH_SDPA"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.current_platform", RocmPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "ROCM_FLASH"
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.current_platform",
-                   OpenVinoPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == "OPENVINO"
+                   OpenVinoPlatform()), patch.dict('sys.modules',
+                                                   {'openvino': Mock()}):
+            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
+                                       False)
+            assert backend.get_name() == "OPENVINO"
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
-            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
-                                        False)
-            assert backend.name == name
+        if name in ["XFORMERS", "FLASHINFER"]:
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                backend = get_attn_backend(16, torch.float16, torch.float16,
+                                           16, False)
+                assert backend.get_name() == name
 
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
-    # which_attn_to_use
+    # get_attn_backend
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
-    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
-    backend = which_attn_to_use(16, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, None, 8, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, torch.float16, None, 16, False)
-        assert backend.name != STR_FLASH_ATTN_VAL
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
-    backend = which_attn_to_use(17, torch.float16, None, 16, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(17, torch.float16, None, 16, False)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
-    assert backend.name != STR_FLASH_ATTN_VAL
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
+    assert backend.get_name() != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, torch.float16, None, 16, False)
+        get_attn_backend(16, torch.float16, None, 16, False)
@@ -9,7 +9,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import STR_BACKEND_ENV_VAR
+from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 
 logger = init_logger(__name__)
 
@@ -114,83 +114,19 @@ def _cached_get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free, use_v1)
-    if backend == _Backend.FLASH_ATTN:
-        logger.info("Using Flash Attention backend.")
-        from vllm.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend)
-        return FlashAttentionBackend
-    if backend == _Backend.FLASH_ATTN_VLLM_V1:
-        from vllm.v1.attention.backends.flash_attn import (  # noqa: F401
-            FlashAttentionBackend as FlashAttentionBackendV1)
-        return FlashAttentionBackendV1
-    if backend == _Backend.XFORMERS:
-        logger.info("Using XFormers backend.")
-        from vllm.attention.backends.xformers import (  # noqa: F401
-            XFormersBackend)
-        return XFormersBackend
-    elif backend == _Backend.ROCM_FLASH:
-        logger.info("Using ROCmFlashAttention backend.")
-        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
-            ROCmFlashAttentionBackend)
-        return ROCmFlashAttentionBackend
-    elif backend == _Backend.TORCH_SDPA:
-        assert current_platform.is_cpu(), RuntimeError(
-            "Torch SDPA backend is only used for the CPU device.")
-        logger.info("Using Torch SDPA backend.")
-        from vllm.attention.backends.torch_sdpa import TorchSDPABackend
-        return TorchSDPABackend
-    elif backend == _Backend.OPENVINO:
-        logger.info("Using OpenVINO Attention backend.")
-        from vllm.attention.backends.openvino import OpenVINOAttentionBackend
-        return OpenVINOAttentionBackend
-    elif backend == _Backend.IPEX:
-        assert current_platform.is_xpu(), RuntimeError(
-            "IPEX attention backend is only used for the XPU device.")
-        logger.info("Using IPEX attention backend.")
-        from vllm.attention.backends.ipex_attn import IpexAttnBackend
-        return IpexAttnBackend
-    elif backend == _Backend.FLASHINFER:
-        logger.info("Using Flashinfer backend.")
-        from vllm.attention.backends.flashinfer import FlashInferBackend
-        return FlashInferBackend
-    elif backend == _Backend.HPU_ATTN:
-        logger.info("Using HPUAttention backend.")
-        from vllm.attention.backends.hpu_attn import HPUAttentionBackend
-        return HPUAttentionBackend
-    elif backend == _Backend.PALLAS:
-        logger.info("Using Pallas backend.")
-        from vllm.attention.backends.pallas import PallasAttentionBackend
-        return PallasAttentionBackend
-    elif backend == _Backend.NO_ATTENTION:
-        from vllm.attention.backends.placeholder_attn import (
-            PlaceholderAttentionBackend)
-        return PlaceholderAttentionBackend
-    else:
-        raise ValueError("Invalid attention backend.")
-
-
-def which_attn_to_use(head_size: int,
-                      dtype: torch.dtype,
-                      kv_cache_dtype: Optional[str],
-                      block_size: int,
-                      is_attention_free: bool,
-                      use_v1: bool = False) -> _Backend:
-    """Returns which flash attention backend to use."""
-    # Default case.
-    selected_backend = _Backend.FLASH_ATTN
-
     # If there are no attention layers (e.g. we are running Mamba),
     # use the placeholder NO_ATTENTION
     if is_attention_free:
-        return _Backend.NO_ATTENTION
+        from vllm.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
 
     # Check whether a particular choice of backend was
     # previously forced.
     #
     # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
     # ENVIRONMENT VARIABLE.
+    selected_backend = None
     backend_by_global_setting: Optional[_Backend] = (
         get_global_forced_attn_backend())
     if backend_by_global_setting is not None:
@@ -201,64 +137,13 @@ def which_attn_to_use(head_size: int,
         if backend_by_env_var is not None:
             selected_backend = backend_name_to_enum(backend_by_env_var)
 
-    # get device-specific default attn_backend
-    default_backend = current_platform.get_default_attn_backend(
-        selected_backend)
-    if default_backend is not None:
-        return default_backend
-
-    if use_v1:
-        return _Backend.FLASH_ATTN_VLLM_V1
-
-    # FlashAttn in NVIDIA GPUs.
-    if selected_backend == _Backend.FLASH_ATTN:
-        if not current_platform.has_device_capability(80):
-            # Volta and Turing NVIDIA GPUs.
-            logger.info(
-                "Cannot use FlashAttention-2 backend for Volta and Turing "
-                "GPUs.")
-            selected_backend = _Backend.XFORMERS
-        elif dtype not in (torch.float16, torch.bfloat16):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for dtype other than "
-                "torch.float16 or torch.bfloat16.")
-            selected_backend = _Backend.XFORMERS
-        elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
-            logger.info(
-                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
-            logger.warning(
-                "Please use FlashInfer backend with FP8 KV Cache for "
-                "better performance by setting environment variable "
-                "VLLM_ATTENTION_BACKEND=FLASHINFER")
-            selected_backend = _Backend.XFORMERS
-        elif block_size % 16 != 0:
-            logger.info(
-                "Cannot use FlashAttention-2 backend for block size not "
-                "divisible by 16.")
-            selected_backend = _Backend.XFORMERS
-
-    # FlashAttn is valid for the model, checking if the package is installed.
-    if selected_backend == _Backend.FLASH_ATTN:
-        try:
-            import vllm.vllm_flash_attn  # noqa: F401
-            from vllm.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
-
-            supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
-            if head_size not in supported_sizes:
-                logger.info(
-                    "Cannot use FlashAttention-2 backend for head size %d.",
-                    head_size)
-                selected_backend = _Backend.XFORMERS
-        except ImportError:
-            logger.info(
-                "Cannot use FlashAttention-2 backend because the "
-                "vllm.vllm_flash_attn package is not found. "
-                "Make sure that vllm_flash_attn was built and installed "
-                "(on by default).")
-            selected_backend = _Backend.XFORMERS
-
-    return selected_backend
+    # get device-specific attn_backend
+    attention_cls = current_platform.get_attn_backend_cls(
+        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
+    if not attention_cls:
+        raise ValueError(
+            f"Invalid attention backend for {current_platform.device_name}")
+    return resolve_obj_by_qualname(attention_cls)
 
 
 @contextmanager
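Note: with this change the platform hook returns the backend's fully qualified class name as a string, and the selector materialises it with resolve_obj_by_qualname from vllm.utils. The real helper lives in vllm; a rough equivalent showing the assumed behaviour (illustrative sketch only, with a made-up name to avoid confusion with the actual implementation):

import importlib
from typing import Any

def resolve_by_qualname(qualname: str) -> Any:
    # Split "package.module.ClassName" into module path and attribute,
    # import the module, and return the attribute.
    module_name, _, attr = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# Hypothetical use:
# backend_cls = resolve_by_qualname(
#     "vllm.attention.backends.flash_attn.FlashAttentionBackend")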
@@ -28,10 +28,13 @@ class CpuPlatform(Platform):
         return "cpu"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.TORCH_SDPA:
             logger.info("Cannot use %s backend on CPU.", selected_backend)
-        return _Backend.TORCH_SDPA
+        logger.info("Using Torch SDPA backend.")
+        return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -16,7 +16,7 @@ import vllm._C  # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
 
-from .interface import DeviceCapability, Platform, PlatformEnum
+from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -141,6 +141,81 @@ class CudaPlatformBase(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
+    @classmethod
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        if selected_backend == _Backend.FLASHINFER:
+            logger.info("Using FlashInfer backend.")
+            return "vllm.attention.backends.flashinfer.FlashInferBackend"
+        elif selected_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+        elif selected_backend == _Backend.FLASH_ATTN:
+            pass
+        elif selected_backend:
+            raise ValueError(
+                f"Invalid attention backend for {cls.device_name}")
+
+        target_backend = _Backend.FLASH_ATTN
+        if not cls.has_device_capability(80):
+            # Volta and Turing NVIDIA GPUs.
+            logger.info(
+                "Cannot use FlashAttention-2 backend for Volta and Turing "
+                "GPUs.")
+            target_backend = _Backend.XFORMERS
+        elif dtype not in (torch.float16, torch.bfloat16):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for dtype other than "
+                "torch.float16 or torch.bfloat16.")
+            target_backend = _Backend.XFORMERS
+        elif kv_cache_dtype is not None and \
+            kv_cache_dtype.startswith("fp8"):
+            logger.info(
+                "Cannot use FlashAttention-2 backend for FP8 KV cache.")
+            logger.warning(
+                "Please use FlashInfer backend with FP8 KV Cache for "
+                "better performance by setting environment variable "
+                "VLLM_ATTENTION_BACKEND=FLASHINFER")
+            target_backend = _Backend.XFORMERS
+        elif block_size % 16 != 0:
+            logger.info(
+                "Cannot use FlashAttention-2 backend for block size not "
+                "divisible by 16.")
+            target_backend = _Backend.XFORMERS
+
+        # FlashAttn is valid for the model, checking if the package is
+        # installed.
+        if target_backend == _Backend.FLASH_ATTN:
+            try:
+                import vllm.vllm_flash_attn  # noqa: F401
+                from vllm.attention.backends.flash_attn import (  # noqa: F401
+                    FlashAttentionBackend)
+
+                supported_sizes = \
+                    FlashAttentionBackend.get_supported_head_sizes()
+                if head_size not in supported_sizes:
+                    logger.info(
+                        "Cannot use FlashAttention-2 backend for head size %d.",
+                        head_size)
+                    target_backend = _Backend.XFORMERS
+            except ImportError:
+                logger.info(
+                    "Cannot use FlashAttention-2 backend because the "
+                    "vllm.vllm_flash_attn package is not found. "
+                    "Make sure that vllm_flash_attn was built and installed "
+                    "(on by default).")
+                target_backend = _Backend.XFORMERS
+
+        if target_backend == _Backend.XFORMERS:
+            logger.info("Using XFormers backend.")
+            return "vllm.attention.backends.xformers.XFormersBackend"
+
+        logger.info("Using Flash Attention backend.")
+        return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
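Note: the CUDA selection above can still be forced through the VLLM_ATTENTION_BACKEND environment variable referenced in the warning message. A hedged usage sketch, assuming a supported CUDA GPU and the same argument values the tests use:

import os

# Set the override before the selector runs (and caches) its choice.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

import torch

from vllm.attention.selector import get_attn_backend

backend = get_attn_backend(16, torch.float16, None, 16, False)
print(backend.get_name())  # expected to report FLASHINFER on such a setup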
@@ -21,8 +21,11 @@ class HpuPlatform(Platform):
     dispatch_key: str = "HPU"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
-        return _Backend.HPU_ATTN
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        logger.info("Using HPUAttention backend.")
+        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
@@ -112,9 +112,11 @@ class Platform:
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend):
-        """Get the default attention backend of a device."""
-        return None
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
+        """Get the attention backend class of a device."""
+        return ""
 
     @classmethod
     def get_device_capability(
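Note: this base-class hook is the extension point a platform implementation overrides to supply its attention backend. A hypothetical sketch following the signature added here (MyPlatform, PlatformEnum.OOT, and my_plugin.attention.MyAttentionBackend are illustrative names, not part of this commit):

from typing import Optional

import torch

from vllm.platforms.interface import Platform, PlatformEnum, _Backend


class MyPlatform(Platform):
    # Assumed out-of-tree enum value; use whichever PlatformEnum member applies.
    _enum = PlatformEnum.OOT

    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        # Return the fully qualified class name; the selector resolves it
        # with resolve_obj_by_qualname.
        return "my_plugin.attention.MyAttentionBackend"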
@@ -28,10 +28,13 @@ class OpenVinoPlatform(Platform):
     dispatch_key: str = "CPU"
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.OPENVINO:
             logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
-        return _Backend.OPENVINO
+        logger.info("Using OpenVINO Attention backend.")
+        return "vllm.attention.backends.openvino.OpenVINOAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -70,7 +70,8 @@ class RocmPlatform(Platform):
     ]
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1) -> str:
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
         if selected_backend == _Backend.ROCM_FLASH:
@@ -79,7 +80,8 @@ class RocmPlatform(Platform):
                 logger.info("flash_attn is not supported on NAVI GPUs.")
         else:
             logger.info("%s is not supported in AMD GPUs.", selected_backend)
-        return _Backend.ROCM_FLASH
+        logger.info("Using ROCmFlashAttention backend.")
+        return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend"  # noqa: E501
 
     @classmethod
     @lru_cache(maxsize=8)
@@ -24,10 +24,13 @@ class TpuPlatform(Platform):
     ]
 
     @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        return _Backend.PALLAS
+        logger.info("Using Pallas backend.")
+        return "vllm.attention.backends.pallas.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
@@ -21,10 +21,13 @@ class XPUPlatform(Platform):
     dispatch_key: str = "XPU"
 
    @classmethod
-    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
+    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
+                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
+                             block_size: int, use_v1: bool) -> str:
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
-        return _Backend.IPEX
+        logger.info("Using IPEX attention backend.")
+        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
 
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability: