Quick Fix by adding conditional import for flash_attn_varlen_func in flash_attn (#20143)
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
@@ -66,3 +66,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
 def flash_attn_supports_fp8() -> bool:
     return get_flash_attn_version() == 3 and \
         current_platform.get_device_capability().major == 9
+
+
+def is_flash_attn_varlen_func_available() -> bool:
+    return current_platform.is_cuda() or current_platform.is_xpu()
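The new helper only reports whether the varlen kernels exist on the current platform; call sites that need them still have to handle the negative case. A minimal sketch of such a guard, where the wrapper name and error message are hypothetical and not part of this commit:

from vllm.attention.utils.fa_utils import is_flash_attn_varlen_func_available


def require_varlen_support() -> None:
    # flash_attn_varlen_func (and the related scheduler/cache helpers) can
    # only be imported on CUDA and XPU platforms, so fail fast with a clear
    # message elsewhere instead of surfacing a late ImportError.
    if not is_flash_attn_varlen_func_available():
        raise NotImplementedError(
            "flash_attn_varlen_func is not available on this platform "
            "(CUDA or XPU required)")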
@@ -14,10 +14,14 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
-                                           flash_attn_varlen_func,
                                            get_flash_attn_version,
-                                           get_scheduler_metadata,
-                                           reshape_and_cache_flash)
+                                           is_flash_attn_varlen_func_available)
+
+if is_flash_attn_varlen_func_available():
+    from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
+                                               get_scheduler_metadata,
+                                               reshape_and_cache_flash)
+
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import cdiv
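Beyond this specific fix, the same conditional-import pattern can be written generically: probe for the optional dependency first, bind its symbols only if the probe succeeds, and leave an explicit None placeholder otherwise so failures surface at the call site with a clear message rather than at module import. A self-contained sketch of that idea; the module and attribute names below are illustrative assumptions, not vLLM APIs:

import importlib
import importlib.util
from typing import Callable, Optional


def optional_import(module: str, attr: str) -> Optional[Callable]:
    # Return the requested attribute if the module is importable,
    # otherwise None so callers can branch on availability explicitly.
    if importlib.util.find_spec(module) is None:
        return None
    return getattr(importlib.import_module(module), attr, None)


# Resolve the optional kernel once at import time; None means unsupported.
varlen_attention = optional_import("flash_attn", "flash_attn_varlen_func")

if varlen_attention is None:
    # A backend would fall back or raise here at call time instead of
    # crashing with an ImportError when the module is first loaded.
    pass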