Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[export] set enable_gqa in export flash->math decomp (#158604)
Differential Revision: D78524147

For `scaled_dot_product_attention(..., enable_gqa=True)`:
- the Math backend passes the flag through, performing the extra [KV broadcast](6e07d6a0ff/aten/src/ATen/native/transformers/attention.cpp (L902)) if it is set to True
- the Flash backend has no flag and relies on correct indexing in the C++ kernel
- Export used to default to Math for `enable_gqa=True`, but https://github.com/pytorch/pytorch/pull/157893 landed and enabled Flash

At the same time, there is an export-only [decomp](6e07d6a0ff/torch/_decomp/decompositions.py (L4968)) redirecting Flash -> Math that calls with `enable_gqa` unset, because that information isn't available at that point. This led to https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540 crashing: the Math non-GQA variant was called with GQA inputs.

This change assumes GQA in the export decomp whenever the query and key/value head counts differ, setting `enable_gqa = <q num heads> != <kv num heads>`, and relies on the existing backend checks to raise on invalid input shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158604
Approved by: https://github.com/angelayi, https://github.com/drisspg
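For context (not part of the commit; requires a PyTorch build that has the `enable_gqa` flag), a minimal sketch of the GQA call shape this PR is about: 32 query heads attending over 8 shared key/value heads, matching the test added below.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim): 32 query heads share 8 KV heads (group size 4).
q = torch.randn(1, 32, 256, 128)
k = torch.randn(1, 8, 256, 128)
v = torch.randn(1, 8, 256, 128)

out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(out.shape)  # torch.Size([1, 32, 256, 128])

# Without enable_gqa=True, this head-count mismatch is rejected by the
# backends rather than being interpreted as grouped-query attention.
```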
Committed by PyTorch MergeBot
parent f55c5d085e
commit 48fe4ff247
@@ -15106,6 +15106,109 @@ def forward(self, args_0):
    return (abs_1,)""",
        )

    def test_sdpa_gqa(self):
        from torch.nn.attention import sdpa_kernel, SDPBackend

        class Foo(torch.nn.Module):
            def forward(self, q, k, v):
                return F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

        q = torch.randn(1, 32, 256, 128)
        k = torch.randn(1, 8, 256, 128)
        v = torch.randn(1, 8, 256, 128)
        with sdpa_kernel(SDPBackend.MATH):
            ep_math = export(Foo(), (q, k, v))
            ep_math = ep_math.run_decompositions()
        self.assertExpectedInline(
            ep_math.graph_module.code.strip(),
            """\
def forward(self, q, k, v):
    mul = torch.ops.aten.mul.Scalar(q, 0.29730177875068026); q = None
    unsqueeze = torch.ops.aten.unsqueeze.default(k, 2); k = None
    expand = torch.ops.aten.expand.default(unsqueeze, [1, 8, 4, 256, 128]); unsqueeze = None
    clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
    view = torch.ops.aten.view.default(clone, [1, 32, 256, 128]); clone = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(v, 2); v = None
    expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [1, 8, 4, 256, 128]); unsqueeze_1 = None
    clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
    view_1 = torch.ops.aten.view.default(clone_1, [1, 32, 256, 128]); clone_1 = None
    permute = torch.ops.aten.permute.default(view, [0, 1, 3, 2]); view = None
    mul_1 = torch.ops.aten.mul.Scalar(permute, 0.29730177875068026); permute = None
    expand_2 = torch.ops.aten.expand.default(mul, [1, 32, 256, 128]); mul = None
    view_2 = torch.ops.aten.view.default(expand_2, [32, 256, 128]); expand_2 = None
    expand_3 = torch.ops.aten.expand.default(mul_1, [1, 32, 128, 256]); mul_1 = None
    view_3 = torch.ops.aten.view.default(expand_3, [32, 128, 256]); expand_3 = None
    bmm = torch.ops.aten.bmm.default(view_2, view_3); view_2 = view_3 = None
    view_4 = torch.ops.aten.view.default(bmm, [1, 32, 256, 256]); bmm = None
    _softmax = torch.ops.aten._softmax.default(view_4, -1, False)
    eq = torch.ops.aten.eq.Scalar(view_4, -inf); view_4 = None
    logical_not = torch.ops.aten.logical_not.default(eq); eq = None
    any_1 = torch.ops.aten.any.dim(logical_not, -1, True); logical_not = None
    logical_not_1 = torch.ops.aten.logical_not.default(any_1); any_1 = None
    full_like = torch.ops.aten.full_like.default(_softmax, 0, pin_memory = False, memory_format = torch.preserve_format)
    where = torch.ops.aten.where.self(logical_not_1, full_like, _softmax); logical_not_1 = full_like = _softmax = None
    expand_4 = torch.ops.aten.expand.default(where, [1, 32, 256, 256]); where = None
    view_5 = torch.ops.aten.view.default(expand_4, [32, 256, 256]); expand_4 = None
    expand_5 = torch.ops.aten.expand.default(view_1, [1, 32, 256, 128]); view_1 = None
    view_6 = torch.ops.aten.view.default(expand_5, [32, 256, 128]); expand_5 = None
    bmm_1 = torch.ops.aten.bmm.default(view_5, view_6); view_5 = view_6 = None
    view_7 = torch.ops.aten.view.default(bmm_1, [1, 32, 256, 128]); bmm_1 = None
    return (view_7,)""",
        )
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            ep_flash = export(Foo(), (q, k, v))
            ep_flash = ep_flash.run_decompositions()
        self.assertExpectedInline(
            ep_flash.graph_module.code.strip(),
            """\
def forward(self, q, k, v):
    mul = torch.ops.aten.mul.Scalar(q, 0.29730177875068026); q = None
    unsqueeze = torch.ops.aten.unsqueeze.default(k, 2); k = None
    expand = torch.ops.aten.expand.default(unsqueeze, [1, 8, 4, 256, 128]); unsqueeze = None
    clone = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
    view = torch.ops.aten.view.default(clone, [1, 32, 256, 128]); clone = None
    unsqueeze_1 = torch.ops.aten.unsqueeze.default(v, 2); v = None
    expand_1 = torch.ops.aten.expand.default(unsqueeze_1, [1, 8, 4, 256, 128]); unsqueeze_1 = None
    clone_1 = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
    view_1 = torch.ops.aten.view.default(clone_1, [1, 32, 256, 128]); clone_1 = None
    permute = torch.ops.aten.permute.default(view, [0, 1, 3, 2]); view = None
    mul_1 = torch.ops.aten.mul.Scalar(permute, 0.29730177875068026); permute = None
    expand_2 = torch.ops.aten.expand.default(mul, [1, 32, 256, 128]); mul = None
    view_2 = torch.ops.aten.view.default(expand_2, [32, 256, 128]); expand_2 = None
    expand_3 = torch.ops.aten.expand.default(mul_1, [1, 32, 128, 256]); mul_1 = None
    view_3 = torch.ops.aten.view.default(expand_3, [32, 128, 256]); expand_3 = None
    bmm = torch.ops.aten.bmm.default(view_2, view_3); view_2 = view_3 = None
    view_4 = torch.ops.aten.view.default(bmm, [1, 32, 256, 256]); bmm = None
    _softmax = torch.ops.aten._softmax.default(view_4, -1, False)
    eq = torch.ops.aten.eq.Scalar(view_4, -inf); view_4 = None
    logical_not = torch.ops.aten.logical_not.default(eq); eq = None
    any_1 = torch.ops.aten.any.dim(logical_not, -1, True); logical_not = None
    logical_not_1 = torch.ops.aten.logical_not.default(any_1); any_1 = None
    full_like = torch.ops.aten.full_like.default(_softmax, 0, pin_memory = False, memory_format = torch.preserve_format)
    where = torch.ops.aten.where.self(logical_not_1, full_like, _softmax); logical_not_1 = full_like = _softmax = None
    expand_4 = torch.ops.aten.expand.default(where, [1, 32, 256, 256]); where = None
    view_5 = torch.ops.aten.view.default(expand_4, [32, 256, 256]); expand_4 = None
    expand_5 = torch.ops.aten.expand.default(view_1, [1, 32, 256, 128]); view_1 = None
    view_6 = torch.ops.aten.view.default(expand_5, [32, 256, 128]); expand_5 = None
    bmm_1 = torch.ops.aten.bmm.default(view_5, view_6); view_5 = view_6 = None
    view_7 = torch.ops.aten.view.default(bmm_1, [1, 32, 256, 128]); bmm_1 = None
    permute_1 = torch.ops.aten.permute.default(view_7, [2, 0, 1, 3]); view_7 = None
    clone_2 = torch.ops.aten.clone.default(permute_1, memory_format = torch.contiguous_format); permute_1 = None
    permute_2 = torch.ops.aten.permute.default(clone_2, [1, 2, 0, 3]); clone_2 = None
    return (permute_2,)""",
        )
        # test backend check for invalid inputs
        error_type = (
            RuntimeError
            if is_non_strict_test(self._testMethodName)
            else torch._dynamo.exc.TorchRuntimeError
        )
        with self.assertRaisesRegex(
            error_type,
            r"Number of heads in key and value must divide the number of heads",
        ):
            export(Foo(), (torch.randn(1, 33, 256, 128), k, v))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):
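For readers of the expected graphs above (an aside, not part of the commit): the `unsqueeze -> expand -> clone -> view` chain applied to `k` and `v` is the decomposed form of the GQA head broadcast, and is equivalent to repeating each KV head along the head dimension.

```python
import torch

k = torch.randn(1, 8, 256, 128)

# Pattern emitted in the expected graphs: insert a group axis, expand it,
# materialize the result, then fold it back into the head dimension.
k_graph = k.unsqueeze(2).expand(1, 8, 4, 256, 128).contiguous().view(1, 32, 256, 128)

# The same broadcast expressed directly.
k_direct = k.repeat_interleave(4, dim=1)

assert torch.equal(k_graph, k_direct)
```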
@@ -5075,6 +5075,7 @@ def scaled_dot_product_flash_attention_for_cpu(
        is_causal=is_causal,
        dropout_mask=None,
        scale=scale,
        enable_gqa=query.size(1) != key.size(1),
    )
    # Why this change?
    # In pre-dispatch export scaled_dot_product_attention is executed via
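As a rough illustration of the heuristic added above (a sketch, not part of the diff; `infer_enable_gqa` is a hypothetical helper name), the flag is just a head-count comparison on dim 1 of `(batch, heads, seq, head_dim)` inputs; shapes where the KV head count does not divide the query head count are still rejected by the existing backend checks, not by this heuristic.

```python
import torch

def infer_enable_gqa(query: torch.Tensor, key: torch.Tensor) -> bool:
    # Mirrors the added decomp line: dim 1 is the head dimension
    # for (batch, heads, seq, head_dim) SDPA inputs.
    return query.size(1) != key.size(1)

q = torch.randn(1, 32, 256, 128)
k_gqa = torch.randn(1, 8, 256, 128)
k_mha = torch.randn(1, 32, 256, 128)

print(infer_enable_gqa(q, k_gqa))  # True  -> treated as GQA
print(infer_enable_gqa(q, k_mha))  # False -> plain multi-head attention
# A query with 33 heads against 8 KV heads would also return True here;
# the backend check ("Number of heads in key and value must divide the
# number of heads") is what rejects that case.
```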