[Platform] Do not raise error if _Backend is not found (#12023)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
wangxiyuan authored on 2025-01-15 18:14:15 +08:00, committed by GitHub
parent ad388d25a8, commit 3adf0ffda8
6 changed files with 49 additions and 16 deletions

@@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):
 def test_invalid_env(monkeypatch):
-    """Throw an exception if the backend name is invalid."""
+    """Ignore the invalid env variable if it is set."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
-    with pytest.raises(ValueError):
-        get_attn_backend(16, torch.float16, None, 16, False)
+    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        backend = get_attn_backend(32, torch.float16, None, 16, False)
+        assert backend.get_name() == "FLASH_ATTN"
+
+        # when block size == 16, backend will fall back to XFORMERS
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() == "XFORMERS"

@@ -0,0 +1,8 @@
+from vllm.attention.backends.flash_attn import FlashAttentionBackend
+
+
+class DummyAttentionBackend(FlashAttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "Dummy_Backend"

@@ -3,3 +3,7 @@ from vllm.platforms.cuda import CudaPlatform
 
 class DummyPlatform(CudaPlatform):
     device_name = "DummyDevice"
+
+    def get_attn_backend_cls(self, backend_name, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501

@@ -1,3 +1,10 @@
+import torch
+
+from tests.kernels.utils import override_backend_env_variable
+from vllm.attention.selector import get_attn_backend
+from vllm.utils import STR_INVALID_VAL
+
+
 def test_platform_plugins():
     # simulate workload by running an example
     import runpy
@@ -14,3 +21,10 @@ def test_platform_plugins():
         f"Expected DummyDevice, got {current_platform.device_name}, "
         "possibly because current_platform is imported before the plugin"
         f" is loaded. The first import:\n{_init_trace}")
+
+
+def test_oot_attention_backend(monkeypatch):
+    # ignore the backend env variable if it is set
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+    assert backend.get_name() == "Dummy_Backend"
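For this test to see the dummy platform at all, the plugin package must be discoverable; platform plugins are registered through the vllm.platform_plugins entry-point group. A sketch of such a registration, with package and function names assumed for illustration rather than taken from this commit:

    # setup.py of an out-of-tree platform package (illustrative)
    from setuptools import setup

    setup(
        name="vllm_add_dummy_platform",
        version="0.1",
        packages=["vllm_add_dummy_platform"],
        entry_points={
            "vllm.platform_plugins": [
                # plugin name -> function returning the platform's qualified
                # class name (or None to deactivate the plugin)
                "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin",
            ]
        },
    )

    # vllm_add_dummy_platform/__init__.py (illustrative)
    def dummy_platform_plugin():
        return "vllm_add_dummy_platform.dummy_platform.DummyPlatform"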

@@ -190,11 +190,11 @@ class MultiHeadAttention(nn.Module):
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
-        attn_backend = backend_name_to_enum(attn_backend.get_name())
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
+        backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
-        self.attn_backend = attn_backend if attn_backend in {
+        self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA
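The rename from attn_backend to backend keeps the backend class returned by get_attn_backend distinct from the _Backend enum value derived from its name, which may now be None for an out-of-tree backend. A minimal sketch of the resulting selection logic, assuming _Backend is importable from vllm.platforms as in the selector hunk below:

    from typing import Optional
    from vllm.platforms import _Backend

    def resolve_mha_backend(backend: Optional[_Backend]) -> _Backend:
        # None (an unknown or out-of-tree backend name) is in neither set, so
        # it falls through to the TORCH_SDPA default instead of raising.
        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
            backend = _Backend.XFORMERS
        return backend if backend in {_Backend.TORCH_SDPA,
                                      _Backend.XFORMERS} else _Backend.TORCH_SDPA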

@@ -14,16 +14,18 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 logger = init_logger(__name__)
 
 
-def backend_name_to_enum(backend_name: str) -> _Backend:
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
     """
     Convert a string backend name to a _Backend enum value.
+
+    Returns:
+    * _Backend: enum value if backend_name is a valid in-tree type
+    * None: otherwise it's an invalid in-tree type or an out-of-tree platform
+      is loaded.
     """
     assert backend_name is not None
-
-    backend_members = _Backend.__members__
-    if backend_name not in backend_members:
-        raise ValueError(f"Invalid attention backend '{backend_name}'. "
-                         f"Available backends: {', '.join(backend_members)} "
-                         "(case-sensitive).")
-
-    return _Backend[backend_name]
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None
 
 
 def get_env_variable_attn_backend() -> Optional[_Backend]:
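
Taken together, an illustrative sketch of the relaxed helper's behavior, assuming the imports shown in the hunk above:

    from vllm.attention.selector import backend_name_to_enum
    from vllm.platforms import _Backend

    # In-tree names still resolve to their enum value.
    assert backend_name_to_enum("FLASH_ATTN") is _Backend.FLASH_ATTN
    # An out-of-tree name such as "Dummy_Backend" now returns None instead of
    # raising ValueError, so platform plugins can supply their own backends.
    assert backend_name_to_enum("Dummy_Backend") is None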