From 347a0f9e836386e508d76eecfb644e15a5b49a5a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 17 Oct 2025 09:36:38 -0700
Subject: [PATCH] Simplify GQA conditions in sdpa_attention.py (#41699)

Removed unnecessary checks for key being a torch.fx.Proxy in GQA
conditions because fx tracing is no longer supported, and torch.export
supports enable_gqa.

---
 src/transformers/integrations/sdpa_attention.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 0526d135d0c..db36dfc3033 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -32,13 +32,11 @@ def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
     # 1.cuda or Ascend NPU
     #   - torch version >= 2.5
     #   - attention_mask is None (otherwise it will fall back to the math kernel)
-    #   - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     # 2.xpu
     #   - torch version >= 2.8
-    #   - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     if _is_torch_xpu_available:
-        return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
-    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
+        return _is_torch_greater_or_equal_than_2_8
+    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None
 
 
 def sdpa_attention_forward(
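
Note (reviewer aside, not part of the patch): a minimal sketch of the
enable_gqa path this helper gates, assuming torch >= 2.5 and illustrative
tensor shapes:

    import torch
    import torch.nn.functional as F

    # Grouped-query attention: 8 query heads share 2 key/value heads.
    # With enable_gqa=True, SDPA broadcasts the KV heads internally, so
    # callers no longer need to repeat_interleave key/value to 8 heads.
    query = torch.randn(1, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
    key = torch.randn(1, 2, 128, 64)    # (batch, num_kv_heads, seq_len, head_dim)
    value = torch.randn(1, 2, 128, 64)

    out = F.scaled_dot_product_attention(query, key, value, enable_gqa=True)
    assert out.shape == query.shape  # output keeps the query's head count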