[export] set enable_gqa in export flash->math decomp (#158604)

Differential Revision: D78524147

For `scaled_dot_product_attention(..., enable_gqa=True)`:
- the Math backend passes the flag through, performing the extra [KV broadcast](6e07d6a0ff/aten/src/ATen/native/transformers/attention.cpp (L902)) if set to True
- the Flash backend has no flag, and relies on correct indexing in the C++ kernel
- Export used to default to the Math backend for `enable_gqa=True`, but https://github.com/pytorch/pytorch/pull/157893 landed and enabled Flash. At the same time, an export-only [decomp](6e07d6a0ff/torch/_decomp/decompositions.py (L4968)) redirects flash -> math and calls it with `enable_gqa` unset, because that information isn't available at that point. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing: the non-GQA Math variant was called with GQA inputs (see the shape sketch after this list).
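
As a rough illustration of the shapes involved in a GQA call (a minimal sketch with made-up sizes, not the model from the report above):

```python
import torch
import torch.nn.functional as F

# Grouped-query attention: more query heads than KV heads; each KV head
# serves a group of query heads.
B, L, E = 2, 128, 64                 # batch, sequence length, head dim (made up)
q = torch.randn(B, 8, L, E)          # 8 query heads
k = torch.randn(B, 2, L, E)          # 2 KV heads -> groups of 4 query heads
v = torch.randn(B, 2, L, E)

# With enable_gqa=True the Math backend broadcasts the KV heads; with the
# flag unset, the mismatched head counts raise instead.
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```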

This PR assumes GQA in the export decomp whenever the head counts differ, setting `enable_gqa = <q num_heads> != <kv num_heads>` (i.e. `query.size(1) != key.size(1)`), and relies on the prior backend checks to raise on invalid input shapes.
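
For reference, the extra work the Math path does when `enable_gqa` is set boils down to repeating the KV heads so they line up with the query heads before the usual attention math. A simplified Python sketch (function and variable names are mine; the real implementation is the C++ linked above):

```python
import torch

def sdpa_math_gqa_sketch(q, k, v, scale=None):
    # Simplified sketch only: broadcast KV heads for GQA, then plain attention.
    # Inputs are (batch, num_heads, seq_len, head_dim).
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    if num_q_heads != num_kv_heads:  # the condition the decomp now uses
        assert num_q_heads % num_kv_heads == 0, "GQA needs q_heads % kv_heads == 0"
        repeat = num_q_heads // num_kv_heads
        # The "KV broadcast": every KV head is shared by a group of query heads.
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
    scale = scale if scale is not None else q.size(-1) ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v
```

The decomp change below simply infers the flag from the head dimension (`query.size(1)`), so this broadcast kicks in whenever the inputs have GQA-style head counts.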

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158604
Approved by: https://github.com/angelayi, https://github.com/drisspg
Author: Pian Pawakapan
Date: 2025-07-24 14:46:13 +00:00
Committed by: PyTorch MergeBot
Parent: f55c5d085e
Commit: 48fe4ff247
2 changed files with 104 additions and 0 deletions

@@ -5075,6 +5075,7 @@ def scaled_dot_product_flash_attention_for_cpu(
         is_causal=is_causal,
         dropout_mask=None,
         scale=scale,
+        enable_gqa=query.size(1) != key.size(1),
     )
     # Why this change?
     # In pre-dispatch export scaled_dot_product_attention is executed via