Grouped Query Attention (#128898)
### Approach: Using the current function declaration

**Constraint:** `Q_Heads % KV_Heads == 0`

**Major changes:**
- Added a new argument `enable_gqa: bool` to the SDPA function call.
- It gives the third-from-last dimension (the number of heads) a grouped meaning, so the query can carry more heads than the key and value.

Sample use case this would enable: LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)
output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output Shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`.
- When the head counts differ, the function expands the key and value tensors to match the query tensor's number of heads using `repeat_interleave`, so the attention computation stays correct and efficient (a minimal equivalence sketch follows at the end of this description).
- By default the `enable_gqa` flag is set to `False`, which keeps regular SDPA behavior unchanged.

### Benchmarks:

- **sdpa.py: #130634** For different batch sizes, `enable_gqa=True` shows a substantial improvement in the runtime of SDPA:

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128898
Approved by: https://github.com/drisspg
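As referenced in the design-choice list above, here is a minimal equivalence sketch: manually expanding the key/value heads with `repeat_interleave` and running regular SDPA should match the `enable_gqa=True` path. The shapes are illustrative, and a PyTorch build that includes the `enable_gqa` argument from this PR is assumed; this is a check of the described semantics, not the internal implementation.

```python
# Minimal equivalence sketch (assumes a PyTorch build with enable_gqa from this PR).
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq_len, D = 2, 32, 8, 128, 64

query = torch.rand(batch, q_heads, seq_len, D)
key = torch.rand(batch, kv_heads, seq_len, D)
value = torch.rand(batch, kv_heads, seq_len, D)

# Grouped-query path: 32 query heads share 8 KV heads.
out_gqa = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

# Reference path: replicate each KV head q_heads // kv_heads times along the
# head dimension (-3), then run regular SDPA.
groups = q_heads // kv_heads
key_rep = key.repeat_interleave(groups, dim=-3)
value_rep = value.repeat_interleave(groups, dim=-3)
out_ref = F.scaled_dot_product_attention(query, key_rep, value_rep, is_causal=True)

torch.testing.assert_close(out_gqa, out_ref)
print(out_gqa.shape)  # torch.Size([2, 32, 128, 64])
```

Note that the reference path materializes the expanded key/value tensors, while the benchmark numbers above suggest the `enable_gqa=True` path is cheaper.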
Committed by PyTorch MergeBot · parent 05a8540041 · commit d039b14207
```diff
@@ -466,6 +466,7 @@ def gen_nn_functional(fm: FileManager) -> None:
             "dropout_p: float = 0.0",
             "is_causal: bool = False",
             "scale: Optional[float] = None",
+            "enable_gqa: bool = False",
         ]
     )
 )
```
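For context, the hunk above extends the generated type stub for `torch.nn.functional.scaled_dot_product_attention`. A sketch of the resulting signature, reconstructed from the parameter strings in the hunk rather than copied from the generated `.pyi` file:

```python
from typing import Optional
from torch import Tensor

# Sketch of the stub entry produced by gen_nn_functional after this change
# (an assumption based on the hunk above, not the verbatim generated file).
def scaled_dot_product_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attn_mask: Optional[Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> Tensor: ...
```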