Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00
Simplify GQA conditions in sdpa_attention.py (#41699)
Removed the unnecessary checks for key being a torch.fx.Proxy in the GQA conditions: fx tracing is no longer supported, and torch.export supports enable_gqa.
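For context, enable_gqa lets PyTorch's scaled_dot_product_attention broadcast a smaller number of key/value heads across the query heads instead of requiring the caller to expand them. A minimal sketch (shapes are illustrative, not taken from this commit), assuming torch >= 2.5:

    import torch
    import torch.nn.functional as F

    # Illustrative GQA shapes: 8 query heads share 2 key/value heads (group size 4).
    query = torch.randn(1, 8, 16, 64)  # (batch, q_heads, seq, head_dim)
    key   = torch.randn(1, 2, 16, 64)  # (batch, kv_heads, seq, head_dim)
    value = torch.randn(1, 2, 16, 64)

    # With enable_gqa=True (available since torch 2.5), SDPA handles the
    # head-count mismatch internally; no explicit KV-head expansion is needed.
    out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)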
@@ -32,13 +32,11 @@ def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
     # 1.cuda or Ascend NPU
     #   - torch version >= 2.5
     #   - attention_mask is None (otherwise it will fall back to the math kernel)
-    #   - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     # 2.xpu
     #   - torch version >= 2.8
-    #   - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     if _is_torch_xpu_available:
-        return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
-    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
+        return _is_torch_greater_or_equal_than_2_8
+    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None


 def sdpa_attention_forward(
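For comparison, the manual expansion that enable_gqa makes unnecessary looks roughly like this (a hedged sketch; sdpa_without_gqa is a hypothetical helper, not part of this file):

    import torch
    import torch.nn.functional as F

    def sdpa_without_gqa(query, key, value):
        # Hypothetical fallback: replicate each KV head so key/value match the
        # query head count, then call SDPA without the enable_gqa flag.
        n_rep = query.shape[1] // key.shape[1]
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
        return F.scaled_dot_product_attention(query, key, value)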