@@ -2,7 +2,7 @@ from typing import Optional
 import torch
-from ..utils import is_torch_xpu_available, logging
+from ..utils import is_torch_xpu_available, is_torch_npu_available, logging
 from ..utils.import_utils import is_torch_greater_or_equal
@@ -80,6 +80,16 @@ def sdpa_attention_forward(
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()
 
+    # By default, the SDPA interface on Ascend NPU does not dispatch to the FlashAttentionScore operator when the
+    # parameters are passed through as-is; it falls back to a concatenation of smaller internal operators. To use
+    # FlashAttentionScore, either is_causal must be enabled with attention_mask set to None, or is_causal must be
+    # disabled with attention_mask given as a boolean mask. So we adapt the parameters here to meet these conditions.
+    if is_torch_npu_available():
+        if is_causal:
+            attention_mask = None
+        else:
+            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
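
For context, here is a minimal standalone sketch (not part of the diff) of the mask adaptation the new branch performs. The helper name `adapt_mask_for_npu_flash_attention` is hypothetical, and it assumes the additive float mask convention that transformers typically passes to SDPA (0.0 for positions to attend, a large negative value for masked positions).

```python
import torch


def adapt_mask_for_npu_flash_attention(attention_mask, is_causal):
    """Sketch of the parameter adaptation, independent of the NPU runtime.

    Returns (attention_mask, is_causal) in a form that matches the
    FlashAttentionScore dispatch conditions described in the diff:
    causal with no mask, or non-causal with a boolean mask.
    """
    if is_causal:
        # Causal path: drop the mask so the operator derives it internally.
        return None, True
    if attention_mask is None:
        # Hypothetical guard, not present in the diff: nothing to convert.
        return None, False
    # Non-causal path: convert the additive float mask to a boolean mask where
    # True means "attend". Non-zero (large negative) entries become True after
    # .bool(), so logical_not flips them to False (= masked out).
    return torch.logical_not(attention_mask.bool()), False


# Toy usage: batch=1, head=1, query_len=2, key_len=4, last key position masked.
mask = torch.zeros(1, 1, 2, 4)
mask[..., -1] = torch.finfo(mask.dtype).min
bool_mask, causal = adapt_mask_for_npu_flash_attention(mask, is_causal=False)
print(bool_mask)  # True for attended positions, False in the last column
print(causal)     # False
```

The resulting boolean mask follows `scaled_dot_product_attention` semantics, where a True entry means the position participates in attention, which is why the additive mask is inverted rather than compared against zero directly.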