Compare commits
1 commit

Author | SHA1 | Date
---|---|---
 | 909e6cc989 |
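The commit below adapts `sdpa_attention_forward` so that, on Ascend NPU, the SDPA call can dispatch to the FlashAttentionScore operator: it imports `is_torch_npu_available` and rewrites `is_causal`/`attention_mask` into the form that kernel accepts.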
@@ -2,7 +2,7 @@ from typing import Optional
 
 import torch
 
-from ..utils import is_torch_xpu_available, logging
+from ..utils import is_torch_xpu_available, is_torch_npu_available, logging
 from ..utils.import_utils import is_torch_greater_or_equal
 
 
@@ -80,6 +80,16 @@ def sdpa_attention_forward(
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()
 
+    # By default, when passing parameters, the sdpa interface of Ascend NPU cannot invoke the FlashAttentionScore
+    # operator, but instead uses internal small-operator concatenation. To enter FlashAttentionScore, one of the
+    # following conditions must be met: enable is_causal and set attention_mask to None, or disable is_causal and
+    # set attention_mask to a boolean type. So we adapt the parameters to allow entry into FlashAttentionScore.
+    if is_torch_npu_available():
+        if is_causal:
+            attention_mask = None
+        else:
+            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
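The mask conversion in the `else` branch can be exercised in isolation. Below is a minimal sketch (the helper name `adapt_mask_for_npu_sdpa` is hypothetical, not part of this commit), assuming the incoming `attention_mask` is an additive float mask where 0 marks positions to attend and a large negative value marks masked positions:

```python
import torch

def adapt_mask_for_npu_sdpa(attention_mask, is_causal, device):
    # Mirrors the NPU branch in the diff above: FlashAttentionScore is only
    # selected when either is_causal=True with no mask, or is_causal=False
    # with a boolean mask.
    if is_causal:
        return None  # drop the mask and rely on causal masking
    # Turn an additive mask (0 = attend, large negative = masked) into a
    # boolean mask where True = attend, matching torch's SDPA convention.
    return torch.logical_not(attention_mask.bool()).to(device)

# Toy additive mask for a 2x2 attention grid, hiding the last key
# position for the first query.
mask = torch.tensor([[[[0.0, float("-inf")], [0.0, 0.0]]]])
print(adapt_mask_for_npu_sdpa(mask, is_causal=False, device="cpu"))
# tensor([[[[ True, False],
#           [ True,  True]]]])
```

Under that assumption, `.bool()` maps 0 to False and the negative fill values to True, and `logical_not` flips the result into torch's boolean-mask convention (True means the position takes part in attention), which is the form the FlashAttentionScore path accepts.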