Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
[bugfix] fix flash_attention_2 unavailable error on Ascend NPU (#39844)
@@ -267,3 +267,8 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
     sin = sin.unsqueeze(0).unsqueeze(2)
 
     return npu_rotary_mul(x, cos, sin)
+
+
+def get_npu_flash_attn_funcs():
+    # return flash attention related functions used for Ascend NPU in order
+    return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
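For orientation, here is a rough sketch of how a caller could consume the five-element tuple returned by the new get_npu_flash_attn_funcs. The unpacking order follows the return statement above; interpreting the trailing False as an "is flash-attention-3"-style flag is an assumption, as is using the absolute import path (the diff only shows the package-relative one).

# Hedged sketch: unpack the tuple added above. Assumes a transformers build that
# contains this commit; on a machine without torch_npu the import itself may fail.
def load_npu_attention_funcs():
    try:
        from transformers.integrations.npu_flash_attention import get_npu_flash_attn_funcs
    except ImportError:
        return None  # not running on an Ascend NPU build

    attn_func, attn_varlen_func, pad_input, unpad_input, flag = get_npu_flash_attn_funcs()
    return {
        "attn_func": attn_func,                # dense (padded) NPU flash-attention kernel
        "attn_varlen_func": attn_varlen_func,  # variable-length / unpadded variant
        "pad_input": pad_input,
        "unpad_input": unpad_input,
        "flag": flag,                          # False here; meaning inferred, not stated in the diff
    }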
@@ -261,7 +261,7 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = None):
 
 def _lazy_imports(impl: Optional[str]):
     # returns funcs and pad/unpad based on impl
-    is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
+    is_fa2 = is_flash_attn_2_available()
     is_fa3 = is_flash_attn_3_available()
     if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
         try:
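Dropping the or is_torch_npu_available() term means NPU availability no longer masquerades as "flash-attn 2 is importable"; the dedicated NPU branch added in the next hunk takes over instead. A quick, hedged check of the two detection helpers involved (both live in transformers.utils):

from transformers.utils import is_flash_attn_2_available, is_torch_npu_available

is_fa2 = is_flash_attn_2_available()  # no longer ORed with NPU availability
print("flash-attn 2 importable:", is_fa2)
print("Ascend NPU available:", is_torch_npu_available())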
@@ -299,7 +299,12 @@ def _lazy_imports(impl: Optional[str]):
             raise ImportError(
                 "Failed to import flash attention 2, please install it or use another implementation."
             ) from e
-    if impl == "flash_attention_3" or (impl is None and is_fa3):
+    elif is_torch_npu_available():
+        # get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU
+        from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
+
+        return get_npu_flash_attn_funcs()
+    elif impl == "flash_attention_3" or (impl is None and is_fa3):
         from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
 
         pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
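To make the new control flow easier to follow, here is a condensed, self-contained sketch of the dispatch order _lazy_imports appears to end up with after this hunk. The real function imports and returns kernel functions; the pick_backend name and the string labels are purely illustrative.

from typing import Optional

def pick_backend(impl: Optional[str], is_fa2: bool, is_fa3: bool, is_npu: bool) -> str:
    # 1. explicit "flash_attention_2" request, or FA2 auto-detected (and FA3 not preferred)
    if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
        return "flash_attn_2"
    # 2. otherwise, Ascend NPU present -> use the NPU integration added above
    elif is_npu:
        return "npu_flash_attention"
    # 3. otherwise, explicit "flash_attention_3" request or FA3 auto-detected
    elif impl == "flash_attention_3" or (impl is None and is_fa3):
        return "flash_attn_3"
    return "unsupported"

# On an NPU box with no CUDA flash-attn installed and no explicit request:
assert pick_backend(None, is_fa2=False, is_fa3=False, is_npu=True) == "npu_flash_attention"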
@@ -2483,8 +2483,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin):
             preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
             install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
 
-            # package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic
-            if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
+            # package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
+            if is_torch_npu_available():
+                logger.info("Detect using FlashAttention2 on Ascend NPU.")
+                return True
+
+            if importlib.util.find_spec("flash_attn") is None:
                 raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
             else:
                 # Check FA2 installed version compatibility
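Finally, a minimal standalone sketch of the validation pattern this hunk introduces in PreTrainedModel: return early, and successfully, on Ascend NPU before any flash-attn package checks run. The free-standing function name and the shortened install message are illustrative only; the real logic lives inside the model class method shown above.

import importlib.util
import logging

from transformers.utils import is_torch_npu_available

logger = logging.getLogger(__name__)

def check_flash_attention_2_usable() -> bool:
    preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
    install_message = "Please install the `flash-attn` package."  # shortened stand-in for the real message

    # `flash-attn` cannot be installed on Ascend NPU, so skip the package checks entirely there.
    if is_torch_npu_available():
        logger.info("Detect using FlashAttention2 on Ascend NPU.")
        return True

    if importlib.util.find_spec("flash_attn") is None:
        raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")

    # (The real method continues with a flash-attn version compatibility check here.)
    return True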