mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Adapt to the SDPA interface to enable the NPU to call FlashAttentionScore (#41143)
Adapt to the SDPA interface to enable the NPU to call FlashAttentionScore. Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import is_torch_xpu_available, logging
|
||||
from ..utils import is_torch_npu_available, is_torch_xpu_available, logging
|
||||
from ..utils.import_utils import is_torch_greater_or_equal
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ logger = logging.get_logger(__name__)
|
||||
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
|
||||
_is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
|
||||
_is_torch_xpu_available = is_torch_xpu_available()
|
||||
_is_torch_npu_available = is_torch_npu_available()
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
@ -35,8 +36,12 @@ def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -
|
||||
# 2.xpu
|
||||
# - torch version >= 2.8
|
||||
# - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
|
||||
# 3.npu
|
||||
# - npu is not supported gqa currently
|
||||
if _is_torch_xpu_available:
|
||||
return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
|
||||
if _is_torch_npu_available:
|
||||
return False
|
||||
return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
|
||||
|
||||
|
||||
@ -80,6 +85,14 @@ def sdpa_attention_forward(
|
||||
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
|
||||
is_causal = is_causal.item()
|
||||
|
||||
# When `is_causal = False` and the `attention_mask` is not of boolean type, the Ascend NPU's SDPA interface cannot utilize the FlashAttentionScore operator,
|
||||
# and falls back to small-operator concatenation. To invoke the FlashAttentionScore, the attention_mask must be converted to boolean type.
|
||||
# This adaptation ensures the `attention_mask` meets the requirement for using FlashAttentionScore.
|
||||
if _is_torch_npu_available:
|
||||
if attention_mask is not None and attention_mask.dtype != torch.bool:
|
||||
# Convert to boolean type, making sdpa to force call FlashAttentionScore to improve performance.
|
||||
attention_mask = torch.logical_not(attention_mask.bool()).to(query.device)
|
||||
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
|
Reference in New Issue
Block a user