[export] set enable_gqa in export flash->math decomp (#158604)

Differential Revision: D78524147

For `scaled_dot_product_attention(..., enable_gqa=True)`:
- the Math backend passes the flag through, performing the extra [KV broadcast](6e07d6a0ff/aten/src/ATen/native/transformers/attention.cpp (L902)) when it is set to True (a short sketch of that broadcast follows this list)
- the Flash backend has no such flag and relies on correct indexing in the C++ kernel
- Export used to default to the Math backend for `enable_gqa=True`, but https://github.com/pytorch/pytorch/pull/157893 landed and enabled Flash. At the same time, an export-only [decomp](6e07d6a0ff/torch/_decomp/decompositions.py (L4968)) redirects flash -> math and calls with `enable_gqa` unset, because that information isn't available at that point. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing: the non-GQA Math variant was called with GQA inputs.
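
To make the broadcast concrete, here is a minimal sketch (not the ATen code itself) of what the Math backend does for `enable_gqa=True`, using the shapes from the test below: each of the 8 KV heads is expanded to cover its group of 4 query heads. The same unsqueeze/expand/view pattern shows up in the exported graphs.

```python
import torch

# Shapes from the test below: 32 query heads, 8 KV heads (group size 4).
q_heads, kv_heads, seq_len, head_dim = 32, 8, 256, 128
k = torch.randn(1, kv_heads, seq_len, head_dim)

groups = q_heads // kv_heads  # 4 query heads share each KV head
k_broadcast = (
    k.unsqueeze(2)                                   # (1, 8, 1, 256, 128)
    .expand(1, kv_heads, groups, seq_len, head_dim)  # (1, 8, 4, 256, 128)
    .reshape(1, q_heads, seq_len, head_dim)          # (1, 32, 256, 128)
)
# Equivalent to repeating each KV head `groups` times along the head dim.
assert torch.equal(k_broadcast, k.repeat_interleave(groups, dim=1))
```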

The fix assumes GQA in the export decomp whenever the query and key/value head counts differ (dim 1 of the SDPA inputs is the number of heads, not the sequence length), setting `enable_gqa = <q num_heads> != <kv num_heads>` and relying on the prior backend checks to raise on invalid input shapes.
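
For reference, a minimal sketch (not the decomp code itself, same shape assumptions as above) of the head-count check being added and of the GQA semantics it relies on: under the Math backend, `enable_gqa=True` gives the same result as manually repeating the KV heads.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q = torch.randn(1, 32, 256, 128)  # 32 query heads
k = torch.randn(1, 8, 256, 128)   # 8 KV heads
v = torch.randn(1, 8, 256, 128)

# Dim 1 of a (batch, num_heads, seq_len, head_dim) SDPA input is the head count,
# so a mismatch here is treated as GQA; this mirrors the new line in the decomp below.
enable_gqa = q.size(1) != k.size(1)
assert enable_gqa

groups = q.size(1) // k.size(1)
with sdpa_kernel(SDPBackend.MATH):
    out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    out_ref = F.scaled_dot_product_attention(
        q, k.repeat_interleave(groups, dim=1), v.repeat_interleave(groups, dim=1)
    )
torch.testing.assert_close(out_gqa, out_ref)
```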

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158604
Approved by: https://github.com/angelayi, https://github.com/drisspg
Author: Pian Pawakapan
Date: 2025-07-24 14:46:13 +00:00
Committed by: PyTorch MergeBot
Parent: f55c5d085e
Commit: 48fe4ff247
2 changed files with 104 additions and 0 deletions


@@ -15106,6 +15106,109 @@ def forward(self, args_0):
return (abs_1,)""",
)
    def test_sdpa_gqa(self):
        from torch.nn.attention import sdpa_kernel, SDPBackend

        class Foo(torch.nn.Module):
            def forward(self, q, k, v):
                return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

        q = torch.randn(1, 32, 256, 128)
        k = torch.randn(1, 8, 256, 128)
        v = torch.randn(1, 8, 256, 128)

        with sdpa_kernel(SDPBackend.MATH):
            ep_math = export(Foo(), (q, k, v))
            ep_math = ep_math.run_decompositions()
        self.assertExpectedInline(
            ep_math.graph_module.code.strip(),
            """\
def forward(self, q, k, v):
    mul = torch.ops.aten.mul.Scalar(q, 0.29730177875068026); q = None
    unsqueeze = torch.ops.aten.unsqueeze.default(k, 2); k = None
    expand = torch.ops.aten.expand.default(unsqueeze, [1, 8, 4, 256, 128]); unsqueeze = None
    clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
    view = torch.ops.aten.view.default(clone, [1, 32, 256, 128]); clone = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(v, 2); v = None
    expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [1, 8, 4, 256, 128]); unsqueeze_1 = None
    clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
    view_1 = torch.ops.aten.view.default(clone_1, [1, 32, 256, 128]); clone_1 = None
    permute = torch.ops.aten.permute.default(view, [0, 1, 3, 2]); view = None
    mul_1 = torch.ops.aten.mul.Scalar(permute, 0.29730177875068026); permute = None
    expand_2 = torch.ops.aten.expand.default(mul, [1, 32, 256, 128]); mul = None
    view_2 = torch.ops.aten.view.default(expand_2, [32, 256, 128]); expand_2 = None
    expand_3 = torch.ops.aten.expand.default(mul_1, [1, 32, 128, 256]); mul_1 = None
    view_3 = torch.ops.aten.view.default(expand_3, [32, 128, 256]); expand_3 = None
    bmm = torch.ops.aten.bmm.default(view_2, view_3); view_2 = view_3 = None
    view_4 = torch.ops.aten.view.default(bmm, [1, 32, 256, 256]); bmm = None
    _softmax = torch.ops.aten._softmax.default(view_4, -1, False)
    eq = torch.ops.aten.eq.Scalar(view_4, -inf); view_4 = None
    logical_not = torch.ops.aten.logical_not.default(eq); eq = None
    any_1 = torch.ops.aten.any.dim(logical_not, -1, True); logical_not = None
    logical_not_1 = torch.ops.aten.logical_not.default(any_1); any_1 = None
    full_like = torch.ops.aten.full_like.default(_softmax, 0, pin_memory = False, memory_format = torch.preserve_format)
    where = torch.ops.aten.where.self(logical_not_1, full_like, _softmax); logical_not_1 = full_like = _softmax = None
    expand_4 = torch.ops.aten.expand.default(where, [1, 32, 256, 256]); where = None
    view_5 = torch.ops.aten.view.default(expand_4, [32, 256, 256]); expand_4 = None
    expand_5 = torch.ops.aten.expand.default(view_1, [1, 32, 256, 128]); view_1 = None
    view_6 = torch.ops.aten.view.default(expand_5, [32, 256, 128]); expand_5 = None
    bmm_1 = torch.ops.aten.bmm.default(view_5, view_6); view_5 = view_6 = None
    view_7 = torch.ops.aten.view.default(bmm_1, [1, 32, 256, 128]); bmm_1 = None
    return (view_7,)""",
        )

        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            ep_flash = export(Foo(), (q, k, v))
            ep_flash = ep_flash.run_decompositions()
        self.assertExpectedInline(
            ep_flash.graph_module.code.strip(),
            """\
def forward(self, q, k, v):
    mul = torch.ops.aten.mul.Scalar(q, 0.29730177875068026); q = None
    unsqueeze = torch.ops.aten.unsqueeze.default(k, 2); k = None
    expand = torch.ops.aten.expand.default(unsqueeze, [1, 8, 4, 256, 128]); unsqueeze = None
    clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
    view = torch.ops.aten.view.default(clone, [1, 32, 256, 128]); clone = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(v, 2); v = None
    expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [1, 8, 4, 256, 128]); unsqueeze_1 = None
    clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
    view_1 = torch.ops.aten.view.default(clone_1, [1, 32, 256, 128]); clone_1 = None
    permute = torch.ops.aten.permute.default(view, [0, 1, 3, 2]); view = None
    mul_1 = torch.ops.aten.mul.Scalar(permute, 0.29730177875068026); permute = None
    expand_2 = torch.ops.aten.expand.default(mul, [1, 32, 256, 128]); mul = None
    view_2 = torch.ops.aten.view.default(expand_2, [32, 256, 128]); expand_2 = None
    expand_3 = torch.ops.aten.expand.default(mul_1, [1, 32, 128, 256]); mul_1 = None
    view_3 = torch.ops.aten.view.default(expand_3, [32, 128, 256]); expand_3 = None
    bmm = torch.ops.aten.bmm.default(view_2, view_3); view_2 = view_3 = None
    view_4 = torch.ops.aten.view.default(bmm, [1, 32, 256, 256]); bmm = None
    _softmax = torch.ops.aten._softmax.default(view_4, -1, False)
    eq = torch.ops.aten.eq.Scalar(view_4, -inf); view_4 = None
    logical_not = torch.ops.aten.logical_not.default(eq); eq = None
    any_1 = torch.ops.aten.any.dim(logical_not, -1, True); logical_not = None
    logical_not_1 = torch.ops.aten.logical_not.default(any_1); any_1 = None
    full_like = torch.ops.aten.full_like.default(_softmax, 0, pin_memory = False, memory_format = torch.preserve_format)
    where = torch.ops.aten.where.self(logical_not_1, full_like, _softmax); logical_not_1 = full_like = _softmax = None
    expand_4 = torch.ops.aten.expand.default(where, [1, 32, 256, 256]); where = None
    view_5 = torch.ops.aten.view.default(expand_4, [32, 256, 256]); expand_4 = None
    expand_5 = torch.ops.aten.expand.default(view_1, [1, 32, 256, 128]); view_1 = None
    view_6 = torch.ops.aten.view.default(expand_5, [32, 256, 128]); expand_5 = None
    bmm_1 = torch.ops.aten.bmm.default(view_5, view_6); view_5 = view_6 = None
    view_7 = torch.ops.aten.view.default(bmm_1, [1, 32, 256, 128]); bmm_1 = None
    permute_1 = torch.ops.aten.permute.default(view_7, [2, 0, 1, 3]); view_7 = None
    clone_2 = torch.ops.aten.clone.default(permute_1, memory_format = torch.contiguous_format); permute_1 = None
    permute_2 = torch.ops.aten.permute.default(clone_2, [1, 2, 0, 3]); clone_2 = None
    return (permute_2,)""",
        )

        # test backend check for invalid inputs
        error_type = (
            RuntimeError
            if is_non_strict_test(self._testMethodName)
            else torch._dynamo.exc.TorchRuntimeError
        )
        with self.assertRaisesRegex(
            error_type,
            r"Number of heads in key and value must divide the number of heads",
        ):
            export(Foo(), (torch.randn(1, 33, 256, 128), k, v))

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):


@@ -5075,6 +5075,7 @@ def scaled_dot_product_flash_attention_for_cpu(
        is_causal=is_causal,
        dropout_mask=None,
        scale=scale,
        enable_gqa=query.size(1) != key.size(1),
    )
    # Why this change?
    # In pre-dispatch export scaled_dot_product_attention is executed via