[Platform] Do not raise error if _Backend is not found (#12023)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
tests/kernels/test_attention_selector.py
@@ -94,7 +94,12 @@ def test_flash_attn(monkeypatch):

 def test_invalid_env(monkeypatch):
-    """Throw an exception if the backend name is invalid."""
+    """Ignore the invalid env variable if it is set."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
-    with pytest.raises(ValueError):
-        get_attn_backend(16, torch.float16, None, 16, False)
+    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        backend = get_attn_backend(32, torch.float16, None, 16, False)
+        assert backend.get_name() == "FLASH_ATTN"
+
+        # when block size == 16, backend will fall back to XFORMERS
+        backend = get_attn_backend(16, torch.float16, None, 16, False)
+        assert backend.get_name() == "XFORMERS"

tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py (new file)
@@ -0,0 +1,8 @@
+from vllm.attention.backends.flash_attn import FlashAttentionBackend
+
+
+class DummyAttentionBackend(FlashAttentionBackend):
+
+    @staticmethod
+    def get_name() -> str:
+        return "Dummy_Backend"
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
@@ -3,3 +3,7 @@ from vllm.platforms.cuda import CudaPlatform

 class DummyPlatform(CudaPlatform):
     device_name = "DummyDevice"
+
+    def get_attn_backend_cls(self, backend_name, head_size, dtype,
+                             kv_cache_dtype, block_size, use_v1):
+        return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend"  # noqa E501
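Note that the platform hook returns a fully qualified class path rather than a class object; the selector then resolves that string with resolve_obj_by_qualname from vllm.utils (visible in the selector.py hunk below). A minimal sketch of what such a resolution step looks like, with resolve_by_qualname as a stand-in name rather than vLLM's actual helper:

import importlib
from typing import Any


def resolve_by_qualname(qualname: str) -> Any:
    # "pkg.module.ClassName" -> import pkg.module, then look up ClassName
    module_name, _, attr_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Once the dummy plugin package is installed, this returns the class itself:
# resolve_by_qualname(
#     "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend")

Returning a string keeps the platform layer free of imports of backend modules, so a plugin's backend is only imported when it is actually selected.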
tests/plugins_tests/test_platform_plugins.py
@@ -1,3 +1,10 @@
+import torch
+
+from tests.kernels.utils import override_backend_env_variable
+from vllm.attention.selector import get_attn_backend
+from vllm.utils import STR_INVALID_VAL
+
+
 def test_platform_plugins():
     # simulate workload by running an example
     import runpy
@@ -14,3 +21,10 @@ def test_platform_plugins():
         f"Expected DummyDevice, got {current_platform.device_name}, "
         "possibly because current_platform is imported before the plugin"
         f" is loaded. The first import:\n{_init_trace}")
+
+
+def test_oot_attention_backend(monkeypatch):
+    # ignore the backend env variable if it is set
+    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
+    backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
+    assert backend.get_name() == "Dummy_Backend"
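For test_oot_attention_backend to see DummyPlatform at all, the dummy package has to be installed as a platform plugin. A hypothetical packaging sketch; the entry-point group follows vLLM's platform-plugin convention, but the plugin name and function path here are assumptions, not part of this diff:

# setup.py (hypothetical) for the out-of-tree platform package
from setuptools import setup

setup(
    name="vllm_add_dummy_platform",
    version="0.1",
    packages=["vllm_add_dummy_platform"],
    entry_points={
        # group name assumed from vLLM's platform-plugin mechanism
        "vllm.platform_plugins": [
            "dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin",
        ],
    },
)

The plugin function named in the entry point is expected to return the qualified name of the platform class, presumably "vllm_add_dummy_platform.dummy_platform.DummyPlatform" in this package.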
vllm/attention/layer.py
@@ -190,11 +190,11 @@ class MultiHeadAttention(nn.Module):
             kv_cache_dtype=None,
             block_size=16,
             is_attention_free=False)
-        attn_backend = backend_name_to_enum(attn_backend.get_name())
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
+        backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS

-        self.attn_backend = attn_backend if attn_backend in {
+        self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA, _Backend.XFORMERS
         } else _Backend.TORCH_SDPA

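The rename from attn_backend to backend also keeps this site safe under the new Optional contract: when backend_name_to_enum returns None for an out-of-tree backend, None fails the membership test and the expression falls back to TORCH_SDPA. A self-contained sketch with a stand-in enum, assuming nothing beyond what the hunk shows:

import enum
from typing import Optional


class _Backend(enum.Enum):  # stand-in for vLLM's _Backend enum
    TORCH_SDPA = enum.auto()
    XFORMERS = enum.auto()


backend: Optional[_Backend] = None  # e.g. the name was "Dummy_Backend"
attn_backend = (backend if backend in {_Backend.TORCH_SDPA, _Backend.XFORMERS}
                else _Backend.TORCH_SDPA)
assert attn_backend is _Backend.TORCH_SDPA  # None falls back to SDPA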
vllm/attention/selector.py
@@ -14,16 +14,18 @@ from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
 logger = init_logger(__name__)


-def backend_name_to_enum(backend_name: str) -> _Backend:
+def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
+    """
+    Convert a string backend name to a _Backend enum value.
+
+    Returns:
+    * _Backend: enum value if backend_name is a valid in-tree type
+    * None: otherwise it's an invalid in-tree type or an out-of-tree
+      platform is loaded.
+    """
     assert backend_name is not None
-
-    backend_members = _Backend.__members__
-    if backend_name not in backend_members:
-        raise ValueError(f"Invalid attention backend '{backend_name}'. "
-                         f"Available backends: {', '.join(backend_members)} "
-                         "(case-sensitive).")
-
-    return _Backend[backend_name]
+    return _Backend[backend_name] if backend_name in _Backend.__members__ else \
+        None


 def get_env_variable_attn_backend() -> Optional[_Backend]:
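The behavioral change in one self-contained sketch (stand-in enum, not vLLM's code): a name that is not an in-tree member used to raise ValueError, and now it maps to None so selection can fall through to the platform's get_attn_backend_cls:

import enum
from typing import Optional


class Backend(enum.Enum):  # stand-in for vLLM's _Backend
    FLASH_ATTN = enum.auto()
    XFORMERS = enum.auto()
    TORCH_SDPA = enum.auto()


def backend_name_to_enum(name: str) -> Optional[Backend]:
    return Backend[name] if name in Backend.__members__ else None


assert backend_name_to_enum("FLASH_ATTN") is Backend.FLASH_ATTN
# Out-of-tree names no longer raise; the caller defers to the platform plugin.
assert backend_name_to_enum("Dummy_Backend") is None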