Compare commits
1 Commit
Author | SHA1 | Date
---|---|---
| 909e6cc989 |
```diff
@@ -2,7 +2,7 @@ from typing import Optional
 
 import torch
 
-from ..utils import is_torch_xpu_available, logging
+from ..utils import is_torch_xpu_available, is_torch_npu_available, logging
 from ..utils.import_utils import is_torch_greater_or_equal
 
 
@@ -80,6 +80,16 @@ def sdpa_attention_forward(
     if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
         is_causal = is_causal.item()
 
+    # By default, the SDPA interface on Ascend NPU cannot invoke the FlashAttentionScore operator when the parameters
+    # are passed through as-is; it falls back to a concatenation of small internal operators. FlashAttentionScore is
+    # entered only if is_causal is enabled and attention_mask is None, or if is_causal is disabled and attention_mask
+    # is a boolean tensor. So we adapt the parameters here to allow entry into FlashAttentionScore.
+    if is_torch_npu_available():
+        if is_causal:
+            attention_mask = None
+        else:
+            attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
+
     attn_output = torch.nn.functional.scaled_dot_product_attention(
         query,
         key,
```
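To make the mask adaptation concrete, below is a minimal standalone sketch of what the added branch does. The helper name `adapt_mask_for_npu_sdpa` is hypothetical and not part of the patch, and the sketch assumes `attention_mask`, when provided, is the usual additive float mask (0.0 where attention is allowed, a large negative value where it is masked); the patch itself only relies on masked entries being non-zero.

```python
import torch


def adapt_mask_for_npu_sdpa(attention_mask, is_causal, device):
    """Mirror the parameter adaptation from the hunk above (hypothetical helper).

    Assumes attention_mask, when given, is an additive float mask:
    0.0 for positions to attend, a large negative value for masked positions.
    """
    if is_causal:
        # Causal case: drop the mask so FlashAttentionScore can be reached
        # with is_causal=True and attention_mask=None.
        return None
    # Non-causal case: convert the additive mask to a boolean mask.
    # Masked positions hold a non-zero (large negative) value, so .bool() marks
    # them True; logical_not flips that, leaving True where attention is allowed.
    return torch.logical_not(attention_mask.bool()).to(device)


# Toy usage: batch of 1, 4 key positions, last position padded out.
additive_mask = torch.tensor([[0.0, 0.0, 0.0, torch.finfo(torch.float32).min]])
bool_mask = adapt_mask_for_npu_sdpa(additive_mask, is_causal=False, device="cpu")
print(bool_mask)  # tensor([[ True,  True,  True, False]])
```

On the causal side, the patch simply drops the mask and lets `is_causal=True` express the causal structure, which is the other condition under which FlashAttentionScore is selected.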