Commit Graph

1 Commit

6b120c6cf9 Update the sdpa benchmark to measure forward/backward time in isolation (#115986)
# Summary

The benchmarks were getting a little stale, and I think it makes more sense to measure the SDPA forward and backward passes in isolation now rather than end-to-end inside an MHA component.
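
As a rough illustration (not the PR's actual harness), here is a minimal sketch of what timing `F.scaled_dot_product_attention`'s forward and backward passes in isolation could look like. The `benchmark_torch_function_in_microseconds` helper, the tensor shapes, and the CUDA device are assumptions for the example:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark


def benchmark_torch_function_in_microseconds(func, *args, **kwargs):
    # Hypothetical helper: time a callable with torch.utils.benchmark
    # and report the median runtime in microseconds.
    t = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"func": func, "args": args, "kwargs": kwargs},
    )
    return t.blocked_autorange().median * 1e6


# Assumed shapes for illustration: embed_dim 2048 with 16 heads
# gives head_dim 128, matching the table below.
device, dtype = "cuda", torch.bfloat16  # assumes a CUDA device
batch_size, num_heads, seq_len, head_dim = 8, 16, 512, 128

q, k, v = (
    torch.randn(
        batch_size, num_heads, seq_len, head_dim,
        device=device, dtype=dtype, requires_grad=True,
    )
    for _ in range(3)
)

# Forward pass in isolation.
fwd_us = benchmark_torch_function_in_microseconds(
    F.scaled_dot_product_attention, q, k, v, is_causal=True
)

# Backward pass in isolation: run the forward once, then retain the
# graph so .backward() can be invoked repeatedly by the timer.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
grad = torch.randn_like(out)
bwd_us = benchmark_torch_function_in_microseconds(
    lambda: out.backward(grad, retain_graph=True)
)

print(f"forward: {fwd_us:.2f} us, backward: {bwd_us:.2f} us")
```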

This is a prerequisite for gathering the data for https://github.com/pytorch/pytorch/pull/115357

Output from a run:
```shell
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
| batch_size | num_heads | q_seq_len | kv_seq_len | embed_dim | is_causal |     dtype      |    forward_time    |   backward_time    |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
|     1      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 23.86634959839284  | 66.21150835417211  |
|     1      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 23.452017060481012 | 66.90612225793302  |
|     1      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 24.478124547749758 |  76.4232068322599  |
|     1      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 |  24.6928428998217  | 75.76151192188263  |
|     1      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 28.69622849393636  | 114.73898496478796 |
|     1      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 34.399422979913645 | 112.96746158041059 |
|     1      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 |  65.4690912924707  | 216.26344555988908 |
|     1      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 88.57532404363155  | 212.07790216431025 |
|     8      |    16     |    128    |    128     |   2048    |   True    | torch.bfloat16 | 11.582905380055308 | 70.09557797573505  |
|     8      |    16     |    128    |    128     |   2048    |   False   | torch.bfloat16 | 12.068384909071026 | 70.01491216942668  |
|     8      |    16     |    256    |    256     |   2048    |   True    | torch.bfloat16 | 31.671419646590945 | 203.54910241439939 |
|     8      |    16     |    256    |    256     |   2048    |   False   | torch.bfloat16 |  33.0585768679157  | 209.45609430782497 |
|     8      |    16     |    512    |    512     |   2048    |   True    | torch.bfloat16 | 87.43969700299202  | 469.8729298543185  |
|     8      |    16     |    512    |    512     |   2048    |   False   | torch.bfloat16 | 123.9265550393611  | 580.1084265112877  |
|     8      |    16     |   1024    |    1024    |   2048    |   True    | torch.bfloat16 | 561.1918237991632  | 1181.655174586922  |
|     8      |    16     |   1024    |    1024    |   2048    |   False   | torch.bfloat16 | 884.2707145959139  | 1662.4679416418073 |
+------------+-----------+-----------+------------+-----------+-----------+----------------+--------------------+--------------------+
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115986
Approved by: https://github.com/mikaylagawarecki
2023-12-18 22:40:47 +00:00