[export] set enable_gqa in export flash->math decomp (#158604)
Differential Revision: D78524147

For `scaled_dot_product_attention(..., enable_gqa=True)`:

- the Math backend passes the flag through, performing the extra [KV broadcast](6e07d6a0ff/aten/src/ATen/native/transformers/attention.cpp (L902)) when it is set to True
- the Flash backend has no such flag and relies on correct indexing in the C++ kernel
- Export used to default to the Math backend for `enable_gqa=True`, but https://github.com/pytorch/pytorch/pull/157893 landed and enabled Flash. At the same time, there is an export-only [decomp](6e07d6a0ff/torch/_decomp/decompositions.py (L4968)) redirecting flash -> math, which called the Math variant with `enable_gqa` unset because that information isn't available there. This led to the crash reported in https://fb.workplace.com/groups/1028545332188949/posts/1264609398582540: the non-GQA Math variant was called with GQA inputs.

This change makes the export decomp assume GQA whenever the query and KV head counts differ, setting `enable_gqa = query.size(1) != key.size(1)` and relying on the earlier backend checks to raise on invalid input shapes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158604
Approved by: https://github.com/angelayi, https://github.com/drisspg
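For context, a minimal sketch of the GQA call described above (the shapes and head counts are illustrative assumptions, not taken from the report): 8 query heads share 2 KV heads, and with `enable_gqa=True` the Math backend broadcasts K/V across the head groups, while the Flash backend handles the grouping inside its kernel.

```python
import torch
import torch.nn.functional as F

# Illustrative GQA shapes (assumed): 8 query heads share 2 KV heads in the
# (batch, heads, seq, head_dim) layout.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 2, 16, 64)
v = torch.randn(2, 2, 16, 64)

# With enable_gqa=True the Math backend broadcasts K/V across the query-head
# groups; the Flash backend instead handles the grouping in the C++ kernel.
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 16, 64])
```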
Committed by: PyTorch MergeBot
Parent: f55c5d085e
Commit: 48fe4ff247
@@ -5075,6 +5075,7 @@ def scaled_dot_product_flash_attention_for_cpu(
         is_causal=is_causal,
         dropout_mask=None,
         scale=scale,
+        enable_gqa=query.size(1) != key.size(1),
     )
 
     # Why this change?
     # In pre-dispatch export scaled_dot_product_attention is executed via
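To make the added argument concrete, here is an illustrative sketch (not PyTorch source; the helper name and all shapes are assumptions): in the `(batch, heads, seq, head_dim)` layout the decomp operates on, a head-count mismatch between query and key implies grouped-query attention, and setting `enable_gqa=True` makes the Math backend effectively repeat K/V across the query-head groups before the ordinary attention math.

```python
import torch

# Hypothetical helper mirroring the added decomp argument: infer GQA from a
# head-count mismatch (dim 1 in the (batch, heads, seq, head_dim) layout).
def infer_enable_gqa(query: torch.Tensor, key: torch.Tensor) -> bool:
    return query.size(1) != key.size(1)

q = torch.randn(2, 8, 16, 64)   # 8 query heads
k = torch.randn(2, 2, 16, 64)   # 2 KV heads
v = torch.randn(2, 2, 16, 64)

assert infer_enable_gqa(q, k)   # True: shapes imply grouped-query attention

# Roughly what the Math backend's KV broadcast amounts to when the flag is set:
# each KV head serves a group of query heads, so K/V are repeated along the
# head dimension and then plain attention math is applied.
groups = q.size(1) // k.size(1)             # 4 query heads per KV head
k_b = k.repeat_interleave(groups, dim=1)    # (2, 8, 16, 64)
v_b = v.repeat_interleave(groups, dim=1)

scores = q @ k_b.transpose(-2, -1) / q.size(-1) ** 0.5
out = torch.softmax(scores, dim=-1) @ v_b
print(out.shape)                            # torch.Size([2, 8, 16, 64])
```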