[bugfix] fix flash_attention_2 unavailable error on Ascend NPU (#39844)

This commit is contained in:
Zhen
2025-08-07 01:48:52 +08:00
committed by GitHub
parent cf243a1bf8
commit ac0b468465
3 changed files with 18 additions and 4 deletions

View File

@ -267,3 +267,8 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
sin = sin.unsqueeze(0).unsqueeze(2)
return npu_rotary_mul(x, cos, sin)
def get_npu_flash_attn_funcs():
# return flash attention related functions used for Ascend NPU in order
return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False

View File

@ -261,7 +261,7 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = Non
def _lazy_imports(impl: Optional[str]):
# returns funcs and pad/unpad based on impl
is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
is_fa2 = is_flash_attn_2_available()
is_fa3 = is_flash_attn_3_available()
if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
try:
@ -299,7 +299,12 @@ def _lazy_imports(impl: Optional[str]):
raise ImportError(
"Failed to import flash attention 2, please install it or use another implementation."
) from e
if impl == "flash_attention_3" or (impl is None and is_fa3):
elif is_torch_npu_available():
# get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU
from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
return get_npu_flash_attn_funcs()
elif impl == "flash_attention_3" or (impl is None and is_fa3):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input

View File

@ -2483,8 +2483,12 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
# package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
if is_torch_npu_available():
logger.info("Detect using FlashAttention2 on Ascend NPU.")
return True
if importlib.util.find_spec("flash_attn") is None:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
else:
# Check FA2 installed version compatibility