Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new argument `enable_gqa: bool` to the SDPA function call
- It gives meaning to the third-to-last dimension (the number of heads)

Sample use case this would enable: LLama3

```
# LLama3 8b call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output Shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0
- When the head counts differ, the function adjusts the key and value tensors to match the query tensor's head dimension using repeat_interleave, enabling correct and efficient computation in attention mechanisms.
- By default the enable_gqa flag is set to False, which ensures that regular SDPA functionality remains unchanged.

### Benchmarks:

- **sdpa.py: #130634**

For different batch sizes, enable_gqa=True shows a substantial improvement in the run time of SDPA:

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| ---------- | ----------- | ------------ | --------- | ---------- | --------- | ------------------------------ | ------------------------------- |
| 1          | 32          | 8            | 2048      | 2048       | 2048      | 100.71                         | 119.70                          |
| 8          | 32          | 8            | 2048      | 2048       | 2048      | 539.78                         | 628.83                          |
| 16         | 32          | 8            | 2048      | 2048       | 2048      | 1056.81                        | 1225.48                         |
| 32         | 32          | 8            | 2048      | 2048       | 2048      | 2099.54                        | 2440.45                         |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
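For reference, here is a minimal sketch (not the PR's internal implementation) of the equivalence the design-choice bullet describes: `enable_gqa=True` behaves like expanding the KV heads with `repeat_interleave` along dim -3 and then running regular SDPA. Shapes are made up for illustration, and it assumes a PyTorch build that already includes this PR:

```
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq_len, D = 2, 32, 8, 16, 64
query = torch.rand(batch, q_heads, seq_len, D)
key = torch.rand(batch, kv_heads, seq_len, D)
value = torch.rand(batch, kv_heads, seq_len, D)

assert query.size(-3) % key.size(-3) == 0  # the documented constraint
groups = query.size(-3) // key.size(-3)

# Reference path: repeat each KV head `groups` times along the head dim (-3),
# then call regular SDPA on tensors with matching head counts.
key_rep = key.repeat_interleave(groups, dim=-3)
value_rep = value.repeat_interleave(groups, dim=-3)
reference = F.scaled_dot_product_attention(query, key_rep, value_rep, is_causal=True)

# GQA path: let SDPA handle the head-count mismatch directly.
gqa = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
torch.testing.assert_close(reference, gqa)
```

The speedups in the benchmark table above are presumably from the GQA path not having to materialize the repeated key/value tensors the way the manual `repeat_interleave` path does.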
Commit 8bc5ef563e (parent 527f104a69), committed by PyTorch MergeBot.
```
@@ -1955,16 +1955,24 @@ Call this whenever a new thread is created in order to propagate values from
             at::Tensor const& value,
             std::optional<at::Tensor> attn_mask,
             double dropout,
-            bool is_causal) {
+            bool is_causal,
+            bool enable_gqa) {
           return sdp::sdp_params{
-              query, key, value, std::move(attn_mask), dropout, is_causal};
+              query,
+              key,
+              value,
+              std::move(attn_mask),
+              dropout,
+              is_causal,
+              enable_gqa};
         }))
       .def_readonly("query", &sdp::sdp_params::query)
       .def_readonly("key", &sdp::sdp_params::key)
       .def_readonly("value", &sdp::sdp_params::value)
       .def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
       .def_readonly("dropout", &sdp::sdp_params::dropout)
-      .def_readonly("is_causal", &sdp::sdp_params::is_causal);
+      .def_readonly("is_causal", &sdp::sdp_params::is_causal)
+      .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);

   py::enum_<sdp::SDPBackend>(
       py_module,
```
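As a usage note on the bindings above: the new `enable_gqa` field travels through `sdp_params` into backend selection, and the `SDPBackend` enum bound just after it can be used to pin a backend when experimenting with GQA. A small sketch, assuming a recent PyTorch where `torch.nn.attention.sdpa_kernel` and `SDPBackend` are exposed:

```
# Sketch: forcing a specific SDPA backend while using GQA.
# Assumes torch.nn.attention.sdpa_kernel / SDPBackend are available
# (they are in recent PyTorch releases; older builds may differ).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

query = torch.rand(1, 32, 128, 64)  # 32 query heads
key = torch.rand(1, 8, 128, 64)     # 8 KV heads; 32 % 8 == 0
value = torch.rand(1, 8, 128, 64)

with sdpa_kernel([SDPBackend.MATH]):  # restrict dispatch to the math backend
    out = F.scaled_dot_product_attention(
        query, key, value, is_causal=True, enable_gqa=True)

print(out.shape)  # torch.Size([1, 32, 128, 64])
```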