Simplify GQA conditions in sdpa_attention.py (#41699)

Removed the unnecessary checks for key being a torch.fx.Proxy from the GQA conditions: fx tracing is no longer supported, and torch.export supports enable_gqa.
Author: Justin Chu
Date: 2025-10-17 09:36:38 -07:00
Committed by: GitHub
Parent: 7e204ad121
Commit: 347a0f9e83


@@ -32,13 +32,11 @@ def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
     # 1.cuda or Ascend NPU
     #   - torch version >= 2.5
     #   - attention_mask is None (otherwise it will fall back to the math kernel)
-    #   - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     # 2.xpu
     #   - torch version >= 2.8
-    #   - key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
     if _is_torch_xpu_available:
-        return _is_torch_greater_or_equal_than_2_8 and not isinstance(key, torch.fx.Proxy)
-    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
+        return _is_torch_greater_or_equal_than_2_8
+    return _is_torch_greater_or_equal_than_2_5 and attention_mask is None
 
 
 def sdpa_attention_forward(
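For illustration, a hedged sketch of how a gate like this is typically consumed downstream. sdpa_forward_sketch is a hypothetical reduction, not the real sdpa_attention_forward, and repeat_interleave stands in for the library's KV-head expansion:

```python
import torch
import torch.nn.functional as F

def sdpa_forward_sketch(query, key, value, attention_mask=None, is_causal=False):
    # Assumes use_gqa_in_sdpa from the hunk above is in scope.
    # Fast path: let SDPA broadcast KV heads when the gate allows it.
    if use_gqa_in_sdpa(attention_mask, key):
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask,
            is_causal=is_causal, enable_gqa=True,
        )
    # Fallback: expand KV heads to match the query head count, then call SDPA.
    n_rep = query.shape[1] // key.shape[1]
    key = key.repeat_interleave(n_rep, dim=1)
    value = value.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal,
    )
```

With the Proxy checks gone, the fast path is selected purely from the backend and torch version (plus the mask on cuda/NPU), which is what makes the gate compatible with torch.export tracing.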