pytorch/test/dynamo/test_sdpa.py
Apurva Jain 8bc5ef563e Grouped Query Attention (#132689)
### Approach: Using the current function declaration

**Constraint:** Q_Heads % KV_Heads == 0

**Major change:**
- Added a new boolean argument `enable_gqa` to the SDPA function call.
- When it is set, the third-to-last (head) dimension of `key`/`value` may be a divisor of the query's head count, giving that dimension a grouped meaning.

Sample use case this would enable: Llama 3.

```python
import torch
import torch.nn.functional as F

# Llama 3 8B call to SDPA: 32 query heads share 8 KV heads (group size 4).
# Sizes below are illustrative; Llama 3 8B uses head dim D = 128.
batch, seq_len_q, seq_len_kv, D = 4, 2048, 2048, 128
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:

- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`.
- If the head counts differ, the function expands the key and value tensors along the head dimension with `repeat_interleave` to match the query's head count, keeping the attention computation correct and efficient (see the sketch after this list).
- By default the `enable_gqa` flag is `False`, which ensures that regular SDPA behavior remains unchanged.
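
A minimal sketch of that expansion, assuming a hypothetical helper name (the real logic lives inside `scaled_dot_product_attention` and its fused kernels):

```python
import torch

def expand_kv_for_gqa(query, key, value):
    # Hypothetical helper; illustrates only the repeat_interleave expansion.
    q_heads, kv_heads = query.size(-3), key.size(-3)
    assert q_heads % kv_heads == 0, "Q_Heads % KV_Heads must be 0"
    if q_heads != kv_heads:
        group = q_heads // kv_heads
        # Repeat each KV head `group` times along the head dimension (-3).
        key = key.repeat_interleave(group, dim=-3)
        value = value.repeat_interleave(group, dim=-3)
    return key, value
```

With the Llama 3 shapes above, each of the 8 KV heads is repeated 4 times, so the expanded key/value match the query's 32 heads.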

### Benchmarks:

- **sdpa.py: #130634**
For various batch sizes, `enable_gqa=True` shows a substantial improvement in the runtime of SDPA (a minimal timing sketch follows the table):

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |
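
A sketch of how such a measurement could be taken with `torch.utils.benchmark` (illustrative only; the table above comes from the script in #130634, and the sizes and dtype here are assumptions):

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

# Assumed shapes matching the table: embed_dim 2048 over 32 heads -> D = 64.
# Assumes a CUDA device is available.
batch, q_heads, kv_heads, seq_len, D = 8, 32, 8, 2048, 64
q = torch.rand(batch, q_heads, seq_len, D, device="cuda", dtype=torch.bfloat16)
k = torch.rand(batch, kv_heads, seq_len, D, device="cuda", dtype=torch.bfloat16)
v = torch.rand(batch, kv_heads, seq_len, D, device="cuda", dtype=torch.bfloat16)

# GQA path: pass the smaller KV tensors directly.
t = Timer(
    stmt="F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
print(t.timeit(100))  # an enable_gqa=False baseline would expand k/v first
```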

![Screenshot 2024-07-25 at 9 07 40 PM](https://github.com/user-attachments/assets/a3e5f716-c39f-4096-9e6c-82a735e57b7b)

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
2024-08-07 05:35:36 +00:00


# Owner(s): ["module: dynamo"]
import contextlib

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch.backends.cuda import SDPAParams


@contextlib.contextmanager
def allow_in_graph_sdpa_params():
    # Temporarily mark SDPAParams as allowed in the Dynamo graph,
    # restoring the original binding on exit.
    global SDPAParams
    try:
        old = SDPAParams
        SDPAParams = torch._dynamo.allow_in_graph(SDPAParams)
        yield
    finally:
        SDPAParams = old


class TestSDPA(torch._dynamo.test_case.TestCase):
    def assert_ref_equals_params(self, actual, expected):
        self.assertIs(actual.query, expected.query)
        self.assertIs(actual.key, expected.key)
        self.assertIs(actual.value, expected.value)
        self.assertIs(actual.attn_mask, expected.attn_mask)

    def test_returns_SDPAParams(self):
        # Constructing SDPAParams inside a compiled function should
        # compile as a single frame and return the real object.
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                return SDPAParams(q, k, v, m, 0.1, True, False)

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o = fn(q, k, v, m)
            self.assertTrue(isinstance(o, SDPAParams))
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 1)

    def test_graph_break_SDPAParams(self):
        # A graph break after construction splits compilation into two frames
        # but must not drop or rebuild the SDPAParams object.
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(q, k, v, m):
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                torch._dynamo.graph_break()
                return z, q + 1

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o, _ = fn(q, k, v, m)
            self.assertTrue(isinstance(o, SDPAParams))
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 2)

    def test_input_SDPAParams(self):
        # Passing SDPAParams in as an input should round-trip the same object.
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(sdpap, q):
                torch._dynamo.graph_break()
                return sdpap, sdpap.query + q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            s = SDPAParams(q, k, v, m, 0.1, True, False)
            o, _ = fn(s, q)
            self.assertIs(o, s)
            self.assertEqual(counter.frame_count, 1)

    def test_intermediate_attr_access_SDPAParams(self):
        # Attribute access on an SDPAParams built mid-graph should not
        # force a graph break.
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                q += 1
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                a = z.query
                return a + 1, z, q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            _, o, _ = fn(q, k, v, m)
            expected = SDPAParams(q, k, v, m, 0.1, True, False)
            self.assert_ref_equals_params(o, expected)
            self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()