Grouped Query Attention (#132689)

### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument enable_gqa: bool to sdpa function call
- It adds a meaning to the last third dimension.

Sample use cases this would enable:
LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape
(batch, 32, seq_len_q, D)
```

### Design Choice:

- Check if Query.size(-3) == Key.size(-3) == Value.size(-3) or, Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave if their number of heads are not equal, facilitating correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**
For different batch sizes enable_gqa=True shows a substansial improvement in the run_time of sdpa

 | batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True   |   forward_time when enable_gqa=False    |
| ------------ | ------------- | -------------- | ----------- | ------------ | ----------- | ----------- | ---------------- |
|     1      |     32      |      8       |   2048    |    2048    |   2048    |   100.71  |  119.70  |
|     8      |     32      |      8       |   2048    |    2048    |   2048    |   539.78  |  628.83  |
|     16     |     32      |      8       |   2048    |    2048    |   2048    |   1056.81  |  1225.48  |
|     32      |     32      |      8       |   2048    |    2048    |   2048    |   2099.54  |  2440.45  |

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
This commit is contained in:
Apurva Jain
2024-08-07 05:35:36 +00:00
committed by PyTorch MergeBot
parent 527f104a69
commit 8bc5ef563e
19 changed files with 372 additions and 170 deletions

View File

@ -175,6 +175,7 @@ class CausalBias(torch.Tensor):
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
) -> torch.Tensor:
r"""
Handles the logic for computing attention with the specified causal bias.
@ -191,6 +192,7 @@ class CausalBias(torch.Tensor):
are set.
scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
to :math:`\frac{1}{\sqrt{E}}`.
enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.
Returns:
output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
@ -214,10 +216,13 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p,
is_causal=True,
scale=scale,
enable_gqa=enable_gqa,
)
elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
_validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
sdpa_params = SDPAParams(
query, key, value, None, dropout_p, is_causal, enable_gqa
)
if can_use_flash_attention(sdpa_params):
needs_padding = query.size(-1) % 8 != 0
og_head_size = query.size(-1)
@ -266,6 +271,7 @@ class CausalBias(torch.Tensor):
dropout_p=dropout_p,
is_causal=False,
scale=scale,
enable_gqa=enable_gqa,
)
else:
raise ValueError(