Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Misc] Fix flash attention backend log (#4368)
@@ -25,7 +25,7 @@ class _Backend(enum.Enum):
 def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
     backend = _which_attn_to_use(dtype)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention backend.")
+        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
         return FlashAttentionBackend
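For orientation, below is a minimal usage sketch of the selector touched by this hunk. It assumes the module is importable as vllm.attention.selector and that get_attn_backend still takes only the model dtype, as shown above; it is an illustration, not code from this commit.

# Minimal usage sketch (assumption: the selector shown above is importable as
# vllm.attention.selector and get_attn_backend takes only the dtype).
import torch

from vllm.attention.selector import get_attn_backend

backend_cls = get_attn_backend(torch.float16)
# On an Ampere-or-newer GPU with flash_attn installed, this logs
# "Using FlashAttention-2 backend." and returns the FlashAttentionBackend class.
print(backend_cls.__name__)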
@@ -62,12 +62,12 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
     # NVIDIA GPUs.
     if torch.cuda.get_device_capability()[0] < 8:
         # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention backend for Volta and Turing "
+        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                     "GPUs.")
         return _Backend.XFORMERS

     if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention backend for dtype other than "
+        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                     "torch.float16 or torch.bfloat16.")
         return _Backend.XFORMERS
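The two guards in this hunk depend only on the GPU's compute capability and the model dtype. The helper below restates the same check standalone; the function name is hypothetical and exists only for illustration.

# Standalone restatement of the guards above: FlashAttention-2 needs an
# Ampere-or-newer GPU (compute capability >= 8.0) and fp16/bf16 activations.
import torch


def can_use_flash_attn2(dtype: torch.dtype) -> bool:  # hypothetical helper
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        # Volta (sm_70) and Turing (sm_75) fall back to the xFormers backend.
        return False
    return dtype in (torch.float16, torch.bfloat16)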
@@ -75,8 +75,8 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
         import flash_attn  # noqa: F401
     except ImportError:
         logger.info(
-            "Cannot use FlashAttention backend because the flash_attn package "
-            "is not found. Please install it for better performance.")
+            "Cannot use FlashAttention-2 backend because the flash_attn "
+            "package is not found. Please install it for better performance.")
         return _Backend.XFORMERS

     backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
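The last context line reads an environment-variable override. A hedged sketch of using it follows; the accepted values are assumed to match the _Backend enum member names (e.g. "XFORMERS", "FLASH_ATTN"), so check the enum at the top of this file before relying on them.

# Hedged sketch: forcing the backend instead of letting the probes above decide.
# Assumption: the value is matched against _Backend member names.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"  # skip FlashAttention-2 entirely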