### Approach: Using the current function declaration

**Constraint:** `Q_Heads % KV_Heads == 0`

**Major change:**
- Added a new argument `enable_gqa: bool` to the SDPA function call.
- It gives the third-to-last (head) dimension a grouped meaning.

Sample use case this would enable: Llama 3

```
# Llama 3 8B call to SDPA
query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
# Output shape: (batch, 32, seq_len_q, D)
```

### Design Choice:
- Check that `Query.size(-3) == Key.size(-3) == Value.size(-3)`, or that `Query.size(-3) % Key.size(-3) == 0`.
- If the numbers of heads are not equal, the function adjusts the key and value tensors to match the query tensor's head count using `repeat_interleave`, keeping the attention computation correct and efficient.
- By default the `enable_gqa` flag is set to `False`, which ensures that regular SDPA functionality remains unchanged.

### Benchmarks:
- **sdpa.py: #130634** For different batch sizes, `enable_gqa=True` shows a substantial improvement in the runtime of SDPA:

| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward time (enable_gqa=True) | forward time (enable_gqa=False) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |

- **TorchTitan: https://github.com/pytorch/torchtitan/pull/458**

Differential Revision: D60772086
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132689
Approved by: https://github.com/drisspg
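To make the grouping concrete, below is a minimal sketch of the equivalence described in the design-choice bullets: with 32 query heads and 8 KV heads, each KV head serves a group of 4 consecutive query heads, so `enable_gqa=True` should match an ordinary SDPA call in which `key` and `value` are first expanded with `repeat_interleave` along the head dimension. The head counts follow the Llama 3 example above; the concrete sizes (`batch=2`, `D=64`, sequence length 128) are illustrative assumptions.

```
import torch
import torch.nn.functional as F

batch, D = 2, 64
seq_len_q = seq_len_kv = 128
q_heads, kv_heads = 32, 8
assert q_heads % kv_heads == 0  # the constraint enable_gqa enforces
group = q_heads // kv_heads     # 4 query heads share each KV head

query = torch.rand(batch, q_heads, seq_len_q, D)
key = torch.rand(batch, kv_heads, seq_len_kv, D)
value = torch.rand(batch, kv_heads, seq_len_kv, D)

# Grouped-query attention: KV heads are shared across groups of query heads.
out_gqa = F.scaled_dot_product_attention(
    query, key, value, is_causal=True, enable_gqa=True
)

# Manual expansion: repeat each KV head `group` times along the head
# dimension (-3), then run the ordinary equal-head-count SDPA.
out_ref = F.scaled_dot_product_attention(
    query,
    key.repeat_interleave(group, dim=-3),
    value.repeat_interleave(group, dim=-3),
    is_causal=True,
)

assert out_gqa.shape == (batch, q_heads, seq_len_q, D)
# Both paths compute the same attention; tolerances absorb any
# backend-dependent floating-point differences.
torch.testing.assert_close(out_gqa, out_ref)
```

The manual path materializes the repeated KV heads, whereas `enable_gqa=True` lets the backend exploit the grouping directly, which is where the benchmark gains above come from.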
# Owner(s): ["module: dynamo"]
import contextlib

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch.backends.cuda import SDPAParams


@contextlib.contextmanager
def allow_in_graph_sdpa_params():
    # Temporarily swap the module-global SDPAParams for a version that
    # Dynamo is allowed to trace into the graph; restore it on exit.
    global SDPAParams
    try:
        old = SDPAParams
        SDPAParams = torch._dynamo.allow_in_graph(SDPAParams)
        yield
    finally:
        SDPAParams = old


class TestSDPA(torch._dynamo.test_case.TestCase):
    def assert_ref_equals_params(self, actual, expected):
        self.assertIs(actual.query, expected.query)
        self.assertIs(actual.key, expected.key)
        self.assertIs(actual.value, expected.value)
        self.assertIs(actual.attn_mask, expected.attn_mask)

    def test_returns_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                return SDPAParams(q, k, v, m, 0.1, True, False)

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o = fn(q, k, v, m)
            self.assertTrue(isinstance(o, SDPAParams))
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 1)

    def test_graph_break_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(q, k, v, m):
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                torch._dynamo.graph_break()
                return z, q + 1

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o, _ = fn(q, k, v, m)
            self.assertTrue(isinstance(o, SDPAParams))
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 2)

    def test_input_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(sdpap, q):
                torch._dynamo.graph_break()
                return sdpap, sdpap.query + q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            s = SDPAParams(q, k, v, m, 0.1, True, False)
            o, _ = fn(s, q)
            self.assertIs(o, s)
            self.assertEqual(counter.frame_count, 1)

    def test_intermediate_attr_access_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                q += 1
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                a = z.query
                return a + 1, z, q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            _, o, _ = fn(q, k, v, m)
            expected = SDPAParams(q, k, v, m, 0.1, True, False)
            self.assert_ref_equals_params(o, expected)
            self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()