Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument `enable_gqa: bool` to the sdpa function call
- It gives meaning to the third-to-last dimension (the number of heads)

Sample use cases this would enable: LLama3

```
# LLama3 8b call to SDPA
import torch
from torch.nn.functional import scaled_dot_product_attention

batch, seq_len_q, seq_len_kv, D = 1, 2048, 2048, 128  # illustrative sizes; D is the head dim

query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or Query.size(-3) % Key.size(-3) == 0
- The function adjusts the key and value tensors to match the query tensor's head dimension by using repeat_interleave when their numbers of heads differ, enabling correct and efficient computation in attention mechanisms (see the sketch after this message).
- By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**

For different batch sizes, enable_gqa=True shows a substantial improvement in the runtime of sdpa:

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time when enable_gqa=True | forward_time when enable_gqa=False |
| ---------- | ----------- | ------------ | --------- | ---------- | --------- | --------------------------------- | ---------------------------------- |
| 1  | 32 | 8 | 2048 | 2048 | 2048 | 100.71  | 119.70  |
| 8  | 32 | 8 | 2048 | 2048 | 2048 | 539.78  | 628.83  |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
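To make the design choice above concrete, here is a minimal sketch of how the GQA head adjustment can be emulated in plain PyTorch: the key/value heads are expanded with repeat_interleave so that a standard (non-GQA) sdpa call produces the same result as enable_gqa=True. The helper name `sdpa_gqa_reference` and the concrete shapes are illustrative, not part of this PR.

```python
import torch
import torch.nn.functional as F


def sdpa_gqa_reference(query, key, value, is_causal=False):
    """Illustrative reference: emulate enable_gqa=True by expanding the KV heads.

    Assumes query.size(-3) % key.size(-3) == 0, the constraint enforced by the PR.
    """
    q_heads, kv_heads = query.size(-3), key.size(-3)
    assert q_heads % kv_heads == 0, "Q_Heads must be divisible by KV_Heads"
    if q_heads != kv_heads:
        group_size = q_heads // kv_heads
        # Each KV head is shared by `group_size` consecutive query heads.
        key = key.repeat_interleave(group_size, dim=-3)
        value = value.repeat_interleave(group_size, dim=-3)
    return F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)


# LLama3-8b-like head layout (32 query heads, 8 KV heads); small sizes for speed.
batch, seq_len, D = 2, 128, 128
q = torch.rand(batch, 32, seq_len, D)
k = torch.rand(batch, 8, seq_len, D)
v = torch.rand(batch, 8, seq_len, D)

out_reference = sdpa_gqa_reference(q, k, v, is_causal=True)
out_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
# Both paths produce shape (batch, 32, seq_len, D); allow for small
# backend-dependent numerical differences.
torch.testing.assert_close(out_reference, out_gqa, rtol=1e-4, atol=1e-4)
```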
Committed by: PyTorch MergeBot
Parent: 527f104a69
Commit: 8bc5ef563e
@@ -175,6 +175,7 @@ class CausalBias(torch.Tensor):
         dropout_p: float = 0.0,
         is_causal: bool = False,
         scale: Optional[float] = None,
+        enable_gqa: bool = False,
     ) -> torch.Tensor:
         r"""
         Handles the logic for computing attention with the specified causal bias.
@@ -191,6 +192,7 @@ class CausalBias(torch.Tensor):
                 are set.
             scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
                 to :math:`\frac{1}{\sqrt{E}}`.
+            enable_gqa (optional bool): If set to True, Grouped Query Attention (GQA) is enabled, by default it is set to False.

         Returns:
             output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
@@ -214,10 +216,13 @@ class CausalBias(torch.Tensor):
                 dropout_p=dropout_p,
                 is_causal=True,
                 scale=scale,
+                enable_gqa=enable_gqa,
             )
         elif attn_mask.variant == CausalVariant.LOWER_RIGHT:
             _validate_sdpa_input(query, key, value, None, dropout_p, is_causal, scale)
-            sdpa_params = SDPAParams(query, key, value, None, dropout_p, is_causal)
+            sdpa_params = SDPAParams(
+                query, key, value, None, dropout_p, is_causal, enable_gqa
+            )
             if can_use_flash_attention(sdpa_params):
                 needs_padding = query.size(-1) % 8 != 0
                 og_head_size = query.size(-1)
@@ -266,6 +271,7 @@ class CausalBias(torch.Tensor):
                 dropout_p=dropout_p,
                 is_causal=False,
                 scale=scale,
+                enable_gqa=enable_gqa,
             )
         else:
             raise ValueError(
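The hunks above thread enable_gqa through CausalBias's dispatch path, so GQA also works when a causal bias object is passed as the attention mask. A minimal usage sketch, assuming a PyTorch build that includes this change; the shapes and sizes are illustrative:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

# GQA-shaped inputs: 32 query heads sharing 8 KV heads.
batch, seq_len_q, seq_len_kv, D = 2, 64, 128, 128
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

# LOWER_RIGHT causal bias; the patched _dispatch forwards enable_gqa to sdpa
# (and into SDPAParams when probing the flash-attention backend).
attn_bias = causal_lower_right(seq_len_q, seq_len_kv)
out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_bias, enable_gqa=True
)
print(out.shape)  # torch.Size([2, 32, 64, 128])
```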