Adapt to the SDPA interface to enable the NPU to call FlashAttentionScore (#41143)

Adapt to the SDPA interface to enable the NPU to call FlashAttentionScore.

Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
魅影
2025-09-30 22:19:57 +08:00
committed by GitHub
parent cf0887f62c
commit 2dd175e6bb

View File

@ -2,7 +2,7 @@ from typing import Optional
import torch
from ..utils import is_torch_xpu_available, logging
from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
from ..utils.import_utils import is_torch_greater_or_equal
@ -12,6 +12,7 @@ logger = logging.get_logger(__name__)
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
_is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
_is_torch_xpu_available = is_torch_xpu_available()
_is_torch_npu_available = is_torch_npu_available()
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -35,8 +36,12 @@ def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -
# 2.xpu
# - torch version >= 2.8
# - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
# 3.npu
# - npu is not supported gqa currently
if _is_torch_xpu_available:
return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
if _is_torch_npu_available:
return False
return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
@ -80,6 +85,14 @@ def sdpa_attention_forward(
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()
# When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator
# and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
# This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
if _is_torch_npu_available:
if attention_mask is not None and attention_mask.dtype != torch.bool:
# Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,