Add Lowering for FlexAttention Backwards (#125515)
# Summary

#### What does this PR do?

It enables Inductor to actually generate the fused flex attention kernel for the backwards pass.

I did some other things along the way:

- Abstract out the `build_subgraph_buffer` subroutine and make it reusable between flex attention and flex_attention backwards. In total we need to build 3 subgraphs for fwd + bwd: 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes parts of the forward (more efficiently, since we already have the row_max via logsumexp), therefore we need to inline both the fwd graph and the joint graph in the bwd kernel.
- The version of the backwards kernel is from a somewhat older version of the Triton tutorial implementation. I think we should update it to a newer version in a follow-up. Notably, the blocks need to be square for this to work as currently implemented. I am sure there are many opportunities for optimization.
- I didn't correctly register the decomp table + IndexMode when I landed https://github.com/pytorch/pytorch/pull/123902; this remedies that.
- The rel_bias helper func was reversed in terms of causality. I updated it and then added a test specifically for "future causal" attention.
- The main point that I think still needs to be worked out is the `store_output` call. I have it hacked up to be 'fake', but I don't think we want to land that; we likely want to just have a mutated 'dq' and a stored_output 'dk'.
- I also needed to update the `TritonTemplateKernel` to actually accept multiple subgraphs (modifications).
- I updated the benchmark to also profile bwd performance.

### Benchmark Numbers

_The current implementation is not parallelizing over ctx length in the bwd._ Speedups are `eager_time / compiled_time`, so values below 1 mean the compiled kernel is slower than eager.

FWD Speedups

| Type    | Speedup | shape              | score_mod | dtype          |
|---------|---------|--------------------|-----------|----------------|
| Average | 0.991   |                    |           |                |
| Max     | 1.182   | (16, 16, 4096, 64) | noop      | torch.bfloat16 |
| Min     | 0.796   | (2, 16, 512, 256)  | head_bias | torch.bfloat16 |

BWD Speedups

| Type    | Speedup | shape              | score_mod | dtype          |
|---------|---------|--------------------|-----------|----------------|
| Average | 0.291   |                    |           |                |
| Max     | 0.652   | (8, 16, 512, 64)   | head_bias | torch.bfloat16 |
| Min     | 0.073   | (2, 16, 4096, 128) | head_bias | torch.bfloat16 |

<details>
<summary>Full Data</summary>

| shape | score_mod | dtype | fwd_eager_time | fwd_compiled_time | bwd_eager_time | bwd_compiled_time | fwd_speedup | bwd_speedup |
|-------|-----------|-------|----------------|-------------------|----------------|-------------------|-------------|-------------|
| (2, 16, 512, 64) | noop | torch.bfloat16 | 19.936 | 19.092 | 57.851 | 193.564 | 1.044 | 0.299 |
| (2, 16, 512, 64) | causal_mask | torch.bfloat16 | 19.955 | 19.497 | 57.662 | 206.278 | 1.024 | 0.280 |
| (2, 16, 512, 64) | relative_bias | torch.bfloat16 | 19.455 | 21.297 | 57.674 | 195.219 | 0.913 | 0.295 |
| (2, 16, 512, 64) | head_bias | torch.bfloat16 | 19.958 | 21.289 | 57.674 | 193.859 | 0.938 | 0.298 |
| (2, 16, 512, 128) | noop | torch.bfloat16 | 28.157 | 28.615 | 82.831 | 454.211 | 0.984 | 0.182 |
| (2, 16, 512, 128) | causal_mask | torch.bfloat16 | 28.154 | 28.444 | 83.091 | 432.083 | 0.990 | 0.192 |
| (2, 16, 512, 128) | relative_bias | torch.bfloat16 | 28.722 | 27.897 | 83.175 | 446.789 | 1.030 | 0.186 |
| (2, 16, 512, 128) | head_bias | torch.bfloat16 | 28.299 | 27.673 | 83.052 | 459.179 | 1.023 | 0.181 |
| (2, 16, 512, 256) | noop | torch.bfloat16 | 41.167 | 50.504 | 175.019 | 1083.545 | 0.815 | 0.162 |
| (2, 16, 512, 256) | causal_mask | torch.bfloat16 | 41.656 | 51.933 | 175.078 | 1171.176 | 0.802 | 0.149 |
| (2, 16, 512, 256) | relative_bias | torch.bfloat16 | 41.697 | 50.722 | 175.159 | 1097.312 | 0.822 | 0.160 |
| (2, 16, 512, 256) | head_bias | torch.bfloat16 | 41.690 | 52.387 | 175.184 | 1097.336 | 0.796 | 0.160 |
| (2, 16, 1024, 64) | noop | torch.bfloat16 | 39.232 | 37.454 | 127.847 | 612.430 | 1.047 | 0.209 |
| (2, 16, 1024, 64) | causal_mask | torch.bfloat16 | 39.930 | 39.599 | 127.755 | 665.359 | 1.008 | 0.192 |
| (2, 16, 1024, 64) | relative_bias | torch.bfloat16 | 39.417 | 41.304 | 127.902 | 614.990 | 0.954 | 0.208 |
| (2, 16, 1024, 64) | head_bias | torch.bfloat16 | 39.965 | 42.034 | 127.953 | 613.273 | 0.951 | 0.209 |
| (2, 16, 1024, 128) | noop | torch.bfloat16 | 63.964 | 71.024 | 226.510 | 1637.669 | 0.901 | 0.138 |
| (2, 16, 1024, 128) | causal_mask | torch.bfloat16 | 63.843 | 72.451 | 226.750 | 1558.949 | 0.881 | 0.145 |
| (2, 16, 1024, 128) | relative_bias | torch.bfloat16 | 64.301 | 70.487 | 226.651 | 1610.063 | 0.912 | 0.141 |
| (2, 16, 1024, 128) | head_bias | torch.bfloat16 | 64.033 | 71.394 | 226.676 | 1668.511 | 0.897 | 0.136 |
| (2, 16, 1024, 256) | noop | torch.bfloat16 | 129.348 | 141.390 | 507.337 | 4405.175 | 0.915 | 0.115 |
| (2, 16, 1024, 256) | causal_mask | torch.bfloat16 | 129.538 | 145.680 | 507.178 | 4768.874 | 0.889 | 0.106 |
| (2, 16, 1024, 256) | relative_bias | torch.bfloat16 | 129.438 | 142.782 | 507.004 | 4401.002 | 0.907 | 0.115 |
| (2, 16, 1024, 256) | head_bias | torch.bfloat16 | 129.058 | 146.242 | 507.547 | 4434.251 | 0.883 | 0.114 |
| (2, 16, 4096, 64) | noop | torch.bfloat16 | 481.606 | 409.120 | 1440.890 | 14147.269 | 1.177 | 0.102 |
| (2, 16, 4096, 64) | causal_mask | torch.bfloat16 | 480.227 | 438.847 | 1434.419 | 14973.386 | 1.094 | 0.096 |
| (2, 16, 4096, 64) | relative_bias | torch.bfloat16 | 480.831 | 458.104 | 1432.935 | 14193.253 | 1.050 | 0.101 |
| (2, 16, 4096, 64) | head_bias | torch.bfloat16 | 480.749 | 452.497 | 1437.040 | 14084.869 | 1.062 | 0.102 |
| (2, 16, 4096, 128) | noop | torch.bfloat16 | 872.534 | 848.275 | 2600.895 | 35156.849 | 1.029 | 0.074 |
| (2, 16, 4096, 128) | causal_mask | torch.bfloat16 | 872.647 | 868.279 | 2587.581 | 31919.531 | 1.005 | 0.081 |
| (2, 16, 4096, 128) | relative_bias | torch.bfloat16 | 871.484 | 827.644 | 2593.989 | 34805.634 | 1.053 | 0.075 |
| (2, 16, 4096, 128) | head_bias | torch.bfloat16 | 871.422 | 856.437 | 2602.482 | 35708.591 | 1.017 | 0.073 |
| (2, 16, 4096, 256) | noop | torch.bfloat16 | 1904.497 | 1758.183 | 6122.416 | 66754.593 | 1.083 | 0.092 |
| (2, 16, 4096, 256) | causal_mask | torch.bfloat16 | 1911.174 | 1762.821 | 6113.207 | 72759.392 | 1.084 | 0.084 |
| (2, 16, 4096, 256) | relative_bias | torch.bfloat16 | 1911.254 | 1727.108 | 6123.530 | 66577.988 | 1.107 | 0.092 |
| (2, 16, 4096, 256) | head_bias | torch.bfloat16 | 1916.977 | 1801.804 | 6118.158 | 67359.680 | 1.064 | 0.091 |
| (8, 16, 512, 64) | noop | torch.bfloat16 | 44.984 | 43.974 | 170.276 | 262.259 | 1.023 | 0.649 |
| (8, 16, 512, 64) | causal_mask | torch.bfloat16 | 45.001 | 46.265 | 170.509 | 274.893 | 0.973 | 0.620 |
| (8, 16, 512, 64) | relative_bias | torch.bfloat16 | 45.466 | 48.211 | 170.606 | 262.759 | 0.943 | 0.649 |
| (8, 16, 512, 64) | head_bias | torch.bfloat16 | 45.481 | 48.435 | 170.267 | 261.265 | 0.939 | 0.652 |
| (8, 16, 512, 128) | noop | torch.bfloat16 | 72.565 | 74.736 | 313.220 | 773.126 | 0.971 | 0.405 |
| (8, 16, 512, 128) | causal_mask | torch.bfloat16 | 72.015 | 75.755 | 313.311 | 775.513 | 0.951 | 0.404 |
| (8, 16, 512, 128) | relative_bias | torch.bfloat16 | 72.105 | 74.189 | 313.806 | 769.238 | 0.972 | 0.408 |
| (8, 16, 512, 128) | head_bias | torch.bfloat16 | 72.005 | 74.364 | 313.509 | 775.237 | 0.968 | 0.404 |
| (8, 16, 512, 256) | noop | torch.bfloat16 | 138.656 | 165.453 | 663.707 | 2672.067 | 0.838 | 0.248 |
| (8, 16, 512, 256) | causal_mask | torch.bfloat16 | 139.096 | 172.613 | 663.593 | 2926.538 | 0.806 | 0.227 |
| (8, 16, 512, 256) | relative_bias | torch.bfloat16 | 139.500 | 168.417 | 663.938 | 2658.629 | 0.828 | 0.250 |
| (8, 16, 512, 256) | head_bias | torch.bfloat16 | 139.776 | 173.549 | 662.920 | 2667.266 | 0.805 | 0.249 |
| (8, 16, 1024, 64) | noop | torch.bfloat16 | 134.883 | 125.004 | 484.706 | 1195.254 | 1.079 | 0.406 |
| (8, 16, 1024, 64) | causal_mask | torch.bfloat16 | 134.297 | 132.875 | 485.420 | 1234.953 | 1.011 | 0.393 |
| (8, 16, 1024, 64) | relative_bias | torch.bfloat16 | 134.839 | 139.231 | 485.470 | 1198.556 | 0.968 | 0.405 |
| (8, 16, 1024, 64) | head_bias | torch.bfloat16 | 133.822 | 136.449 | 485.608 | 1189.198 | 0.981 | 0.408 |
| (8, 16, 1024, 128) | noop | torch.bfloat16 | 235.470 | 234.765 | 886.094 | 2662.944 | 1.003 | 0.333 |
| (8, 16, 1024, 128) | causal_mask | torch.bfloat16 | 236.305 | 241.382 | 886.293 | 2646.984 | 0.979 | 0.335 |
| (8, 16, 1024, 128) | relative_bias | torch.bfloat16 | 236.414 | 233.980 | 885.250 | 2642.178 | 1.010 | 0.335 |
| (8, 16, 1024, 128) | head_bias | torch.bfloat16 | 237.176 | 239.040 | 885.754 | 2665.242 | 0.992 | 0.332 |
| (8, 16, 1024, 256) | noop | torch.bfloat16 | 504.445 | 517.855 | 1978.956 | 9592.906 | 0.974 | 0.206 |
| (8, 16, 1024, 256) | causal_mask | torch.bfloat16 | 502.428 | 536.002 | 1978.611 | 10607.342 | 0.937 | 0.187 |
| (8, 16, 1024, 256) | relative_bias | torch.bfloat16 | 503.396 | 523.960 | 1977.993 | 9539.284 | 0.961 | 0.207 |
| (8, 16, 1024, 256) | head_bias | torch.bfloat16 | 503.818 | 536.014 | 1980.131 | 9576.262 | 0.940 | 0.207 |
| (8, 16, 4096, 64) | noop | torch.bfloat16 | 1970.139 | 1674.930 | 5750.940 | 16724.134 | 1.176 | 0.344 |
| (8, 16, 4096, 64) | causal_mask | torch.bfloat16 | 1959.036 | 1775.056 | 5780.512 | 17390.350 | 1.104 | 0.332 |
| (8, 16, 4096, 64) | relative_bias | torch.bfloat16 | 1947.198 | 1773.869 | 5780.643 | 16779.699 | 1.098 | 0.345 |
| (8, 16, 4096, 64) | head_bias | torch.bfloat16 | 1963.935 | 1829.502 | 5780.018 | 16703.259 | 1.073 | 0.346 |
| (8, 16, 4096, 128) | noop | torch.bfloat16 | 3582.711 | 3362.623 | 10436.069 | 36415.565 | 1.065 | 0.287 |
| (8, 16, 4096, 128) | causal_mask | torch.bfloat16 | 3581.504 | 3499.472 | 10346.869 | 36164.959 | 1.023 | 0.286 |
| (8, 16, 4096, 128) | relative_bias | torch.bfloat16 | 3589.779 | 3337.849 | 10529.621 | 36261.696 | 1.075 | 0.290 |
| (8, 16, 4096, 128) | head_bias | torch.bfloat16 | 3602.265 | 3436.444 | 10458.660 | 36507.790 | 1.048 | 0.286 |
| (8, 16, 4096, 256) | noop | torch.bfloat16 | 7695.923 | 7126.275 | 24643.009 | 140949.081 | 1.080 | 0.175 |
| (8, 16, 4096, 256) | causal_mask | torch.bfloat16 | 7679.939 | 7186.252 | 24538.105 | 157156.067 | 1.069 | 0.156 |
| (8, 16, 4096, 256) | relative_bias | torch.bfloat16 | 7681.374 | 6994.832 | 24549.713 | 140077.179 | 1.098 | 0.175 |
| (8, 16, 4096, 256) | head_bias | torch.bfloat16 | 7679.822 | 7212.278 | 24627.823 | 140675.003 | 1.065 | 0.175 |
| (16, 16, 512, 64) | noop | torch.bfloat16 | 80.126 | 78.291 | 333.719 | 541.165 | 1.023 | 0.617 |
| (16, 16, 512, 64) | causal_mask | torch.bfloat16 | 80.065 | 81.696 | 333.779 | 551.113 | 0.980 | 0.606 |
| (16, 16, 512, 64) | relative_bias | torch.bfloat16 | 80.138 | 86.715 | 333.364 | 542.118 | 0.924 | 0.615 |
| (16, 16, 512, 64) | head_bias | torch.bfloat16 | 80.415 | 85.204 | 333.294 | 536.840 | 0.944 | 0.621 |
| (16, 16, 512, 128) | noop | torch.bfloat16 | 134.964 | 138.025 | 607.093 | 1333.102 | 0.978 | 0.455 |
| (16, 16, 512, 128) | causal_mask | torch.bfloat16 | 134.192 | 141.523 | 606.269 | 1424.318 | 0.948 | 0.426 |
| (16, 16, 512, 128) | relative_bias | torch.bfloat16 | 135.711 | 138.639 | 606.283 | 1327.974 | 0.979 | 0.457 |
| (16, 16, 512, 128) | head_bias | torch.bfloat16 | 135.552 | 140.555 | 607.107 | 1347.370 | 0.964 | 0.451 |
| (16, 16, 512, 256) | noop | torch.bfloat16 | 275.113 | 315.144 | 1301.583 | 5268.153 | 0.873 | 0.247 |
| (16, 16, 512, 256) | causal_mask | torch.bfloat16 | 274.867 | 328.106 | 1302.513 | 5770.594 | 0.838 | 0.226 |
| (16, 16, 512, 256) | relative_bias | torch.bfloat16 | 276.052 | 321.770 | 1302.904 | 5241.920 | 0.858 | 0.249 |
| (16, 16, 512, 256) | head_bias | torch.bfloat16 | 271.409 | 328.839 | 1302.142 | 5266.037 | 0.825 | 0.247 |
| (16, 16, 1024, 64) | noop | torch.bfloat16 | 260.489 | 237.463 | 955.884 | 1817.558 | 1.097 | 0.526 |
| (16, 16, 1024, 64) | causal_mask | torch.bfloat16 | 262.378 | 254.350 | 955.280 | 1843.807 | 1.032 | 0.518 |
| (16, 16, 1024, 64) | relative_bias | torch.bfloat16 | 261.338 | 268.253 | 956.038 | 1820.036 | 0.974 | 0.525 |
| (16, 16, 1024, 64) | head_bias | torch.bfloat16 | 262.153 | 264.156 | 956.023 | 1810.076 | 0.992 | 0.528 |
| (16, 16, 1024, 128) | noop | torch.bfloat16 | 476.475 | 461.413 | 1760.578 | 4306.521 | 1.033 | 0.409 |
| (16, 16, 1024, 128) | causal_mask | torch.bfloat16 | 473.794 | 479.178 | 1761.277 | 4619.439 | 0.989 | 0.381 |
| (16, 16, 1024, 128) | relative_bias | torch.bfloat16 | 473.839 | 463.282 | 1758.692 | 4290.562 | 1.023 | 0.410 |
| (16, 16, 1024, 128) | head_bias | torch.bfloat16 | 472.979 | 472.896 | 1763.086 | 4367.931 | 1.000 | 0.404 |
| (16, 16, 1024, 256) | noop | torch.bfloat16 | 1014.184 | 1026.764 | 3922.997 | 19104.147 | 0.988 | 0.205 |
| (16, 16, 1024, 256) | causal_mask | torch.bfloat16 | 1013.217 | 1039.046 | 3928.382 | 21086.281 | 0.975 | 0.186 |
| (16, 16, 1024, 256) | relative_bias | torch.bfloat16 | 1008.519 | 1015.278 | 3922.133 | 18980.652 | 0.993 | 0.207 |
| (16, 16, 1024, 256) | head_bias | torch.bfloat16 | 1011.360 | 1047.542 | 3931.245 | 19069.172 | 0.965 | 0.206 |
| (16, 16, 4096, 64) | noop | torch.bfloat16 | 3929.850 | 3325.667 | 11411.704 | 23344.280 | 1.182 | 0.489 |
| (16, 16, 4096, 64) | causal_mask | torch.bfloat16 | 3885.262 | 3581.544 | 11390.515 | 23725.639 | 1.085 | 0.480 |
| (16, 16, 4096, 64) | relative_bias | torch.bfloat16 | 3865.737 | 3537.308 | 11489.901 | 23406.330 | 1.093 | 0.491 |
| (16, 16, 4096, 64) | head_bias | torch.bfloat16 | 3880.530 | 3665.249 | 11484.411 | 23299.496 | 1.059 | 0.493 |
| (16, 16, 4096, 128) | noop | torch.bfloat16 | 7030.306 | 6745.715 | 20621.264 | 57464.096 | 1.042 | 0.359 |
| (16, 16, 4096, 128) | causal_mask | torch.bfloat16 | 7095.414 | 7034.385 | 20410.656 | 61660.511 | 1.009 | 0.331 |
| (16, 16, 4096, 128) | relative_bias | torch.bfloat16 | 7084.779 | 6686.497 | 20315.161 | 57243.969 | 1.060 | 0.355 |
| (16, 16, 4096, 128) | head_bias | torch.bfloat16 | 7075.367 | 6863.305 | 20494.385 | 58481.953 | 1.031 | 0.350 |
| (16, 16, 4096, 256) | noop | torch.bfloat16 | 15612.741 | 14297.482 | 55306.847 | 281161.865 | 1.092 | 0.197 |
| (16, 16, 4096, 256) | causal_mask | torch.bfloat16 | 15326.592 | 14263.878 | 55227.806 | 313063.232 | 1.075 | 0.176 |
| (16, 16, 4096, 256) | relative_bias | torch.bfloat16 | 15297.963 | 14007.379 | 54558.029 | 279529.175 | 1.092 | 0.195 |
| (16, 16, 4096, 256) | head_bias | torch.bfloat16 | 15216.160 | 14276.027 | 55081.581 | 280996.826 | 1.066 | 0.196 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515
Approved by: https://github.com/Chillee
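For intuition on the "inline the fwd graph and the joint graph" point above: an FAv2-style backward never saves the attention matrix; it recomputes the probabilities from `q`, `k`, and the saved logsumexp, then runs the softmax backward. Below is a minimal non-fused PyTorch sketch of that recompute, with illustrative names, a single head, and no `1/sqrt(d)` scaling; it is not the Triton kernel code.

```python
import torch

def attention_fwd(q, k, v):
    # Forward saves only the output and the per-row logsumexp.
    s = q @ k.transpose(-2, -1)                 # raw scores; a score_mod would apply here
    lse = torch.logsumexp(s, dim=-1, keepdim=True)
    return torch.exp(s - lse) @ v, lse

def attention_bwd_recompute(q, k, v, lse, d_out):
    # Recompute probabilities from the saved logsumexp: no fresh
    # row-max / row-sum passes are needed, only one matmul + exp.
    s = q @ k.transpose(-2, -1)
    p = torch.exp(s - lse)
    dv = p.transpose(-2, -1) @ d_out            # grad wrt v
    dp = d_out @ v.transpose(-2, -1)            # grad wrt p
    # Softmax backward; with a score_mod, its joint graph is inlined here.
    ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))
    dq = ds @ k                                 # grad wrt q
    dk = ds.transpose(-2, -1) @ q               # grad wrt k
    return dq, dk, dv
```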
Committed by: PyTorch MergeBot
Parent: ae6fdfa539
Commit: 95b9e981c3
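For reference when reading the benchmark diff below: `generate_score_mods()` sweeps four score_mod callables, each taking the raw score plus batch, head, query-index, and key-index tensors. Only `noop`'s signature appears in this excerpt, so treat the bodies below as plausible sketches rather than the file's exact code.

```python
import torch

def noop(score, b, h, m, n):
    return score

def causal_mask(score, b, h, m, n):
    # Keep the score only where the key index n is not ahead of the query index m.
    return torch.where(m >= n, score, -float("inf"))

def relative_bias(score, b, h, m, n):
    # Additive bias based on query/key distance.
    return score + (m - n)

def head_bias(score, b, h, m, n):
    # Additive bias based on the head index.
    return score + 2 * h
```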
```diff
@@ -3,7 +3,7 @@ import itertools
 from collections import defaultdict
 from dataclasses import asdict, dataclass
 from functools import partial
-from typing import Callable, List
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -29,28 +29,32 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
 
 @dataclass(frozen=True)
 class ExperimentConfig:
-    batch_size: int
-    num_heads: int
-    q_seq_len: int
-    k_seq_len: int
-    head_dim: int
+    shape: Tuple[int]
     score_mod: Callable
     dtype: torch.dtype
+    calculate_bwd_time: bool
 
+    def __post_init__(self):
+        assert len(self.shape) == 4, "Shape must be of length 4"
+
     def asdict(self):
-        return asdict(self)
+        # Convert the dataclass instance to a dictionary
+        d = asdict(self)
+        # Remove the 'calculate_bwd_time' key
+        d.pop("calculate_bwd_time", None)
+        return d
 
 
 @dataclass(frozen=True)
+class Times:
+    eager_time: float
+    compiled_time: float
+
+
+@dataclass(frozen=True)
 class ExperimentResults:
-    eager_time: float
-    compiled_time: float
-
-    def get_entries(self) -> List:
-        return [
-            f"{self.eager_time:2f}",
-            f"{self.compiled_time:2f}",
-        ]
+    fwd_times: Times
+    bwd_times: Optional[Times]
 
 
 @dataclass(frozen=True)
@@ -58,29 +62,31 @@ class Experiment:
     config: ExperimentConfig
     results: ExperimentResults
 
-    def get_entries(self) -> List:
-        return self.config.get_entries() + self.results.get_entries()
-
     def asdict(self):
-        dict1 = asdict(self.config)
+        dict1 = self.config.asdict()
         dict2 = asdict(self.results)
         return {**dict1, **dict2}
 
 
 def generate_inputs(
-    batch_size,
-    num_heads,
-    q_sequence_length,
-    kv_sequence_length,
-    head_dim,
-    dtype,
-    device,
+    batch_size: int,
+    num_heads: int,
+    q_sequence_length: int,
+    kv_sequence_length: int,
+    head_dim: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    requires_grad: bool,
 ):
     q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
     kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
 
-    make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
-    make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
+    make_q = partial(
+        torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    make_kv = partial(
+        torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
+    )
     query = (
         make_q()
        .view(batch_size, q_sequence_length, num_heads, head_dim)
@@ -101,14 +107,16 @@ def generate_inputs(
 
 def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
     device = torch.device("cuda")
+    batch_size, num_heads, q_seq_len, head_dim = config.shape
     query, key, value = generate_inputs(
-        config.batch_size,
-        config.num_heads,
-        config.q_seq_len,
-        config.k_seq_len,
-        config.head_dim,
+        batch_size,
+        num_heads,
+        q_seq_len,
+        q_seq_len,
+        head_dim,
         config.dtype,
         device,
+        requires_grad=config.calculate_bwd_time,
     )
 
     def eager_sdpa(query, key, value, _):
@@ -125,23 +133,47 @@ def run_single_experiment(config: ExperimentConfig, dynamic=False) -> Experiment
         compiled_sdpa, query, key, value, score_mod
     )
 
-    return ExperimentResults(
-        eager_time=forward_eager_time,
-        compiled_time=forward_compiled_time,
-    )
+    if config.calculate_bwd_time:
+        out_eager = eager_sdpa(query, key, value, score_mod)
+        dOut = torch.randn_like(out_eager)
+        backward_eager_time = benchmark_torch_function_in_microseconds(
+            out_eager.backward, dOut, retain_graph=True
+        )
+
+        out_compile = compiled_sdpa(query, key, value, score_mod)
+        dOut = torch.randn_like(out_eager)
+        backward_compile_time = benchmark_torch_function_in_microseconds(
+            out_compile.backward, dOut, retain_graph=True
+        )
+
+        return ExperimentResults(
+            fwd_times=Times(forward_eager_time, forward_compiled_time),
+            bwd_times=Times(backward_eager_time, backward_compile_time),
+        )
+    else:
+        return ExperimentResults(
+            fwd_times=Times(forward_eager_time, forward_compiled_time),
+            bwd_times=None,
+        )
 
 
-def calculate_speedup(results: ExperimentResults) -> float:
-    return results.eager_time / results.compiled_time
+def calculate_speedup(results: ExperimentResults, type: str) -> float:
+    if type == "fwd":
+        return results.fwd_times.eager_time / results.fwd_times.compiled_time
+    elif type == "bwd":
+        assert results.bwd_times is not None
+        return results.bwd_times.eager_time / results.bwd_times.compiled_time
+    else:
+        raise ValueError(f"Invalid type {type}")
 
 
 def get_func_name(func):
     return func.__name__.split("<locals>.")[-1].split(" at ")[0]
 
 
-def get_average_speedups(results: List[Experiment]):
+def get_average_speedups(results: List[Experiment], type: str):
     # Calculate speedups
-    speedups = [calculate_speedup(r.results) for r in results]
+    speedups = [calculate_speedup(r.results, type) for r in results]
 
     # Find indices of max and min speedups
     max_speedup_index = np.argmax(speedups)
@@ -177,20 +209,39 @@ def print_results(results: List[Experiment]):
     table_data = defaultdict(list)
     for experiment in results:
         for key, value in experiment.asdict().items():
-            if key == "eager_time" or key == "compiled_time":
-                value = float(value)
-            table_data[key].append(value)
+            if key == "fwd_times":
+                for name, time in value.items():
+                    table_data[f"fwd_{name}"].append(float(time))
+            elif key == "bwd_times":
+                if experiment.config.calculate_bwd_time:
+                    for name, time in value.items():
+                        table_data[f"bwd_{name}"].append(float(time))
+            else:
+                table_data[key].append(value)
 
     # Calculate speedups
-    speedups = [calculate_speedup(r.results) for r in results]
-    table_data["speedup"] = speedups
+    fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
+    table_data["fwd_speedup"] = fwd_speedups
+    if results[0].config.calculate_bwd_time:
+        bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
+        table_data["bwd_speedup"] = bwd_speedups
 
     table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
     print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
-    average_data = get_average_speedups(results)
+    print("\n")
+    print("FWD Speedups".center(125, "="))
+    print("\n")
+    average_data = get_average_speedups(results, type="fwd")
     print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
 
+    if results[0].config.calculate_bwd_time:
+        print("\n")
+        print("BWD Speedups".center(125, "="))
+        print("\n")
+        average_data = get_average_speedups(results, type="bwd")
+        print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
+
 
 def generate_score_mods() -> List[Callable]:
     def noop(score, b, h, m, n):
@@ -208,8 +259,8 @@ def generate_score_mods() -> List[Callable]:
     return [noop, causal_mask, relative_bias, head_bias]
 
 
-def generate_experiment_configs() -> List[ExperimentConfig]:
-    batch_sizes = [1, 8, 16]
+def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
+    batch_sizes = [2, 8, 16]
     num_heads = [16]
     q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
     head_dims = [64, 128, 256]
@@ -228,41 +279,49 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
     ) in itertools.product(
         batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
     ):
+        assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
         all_configs.append(
             ExperimentConfig(
-                batch_size=bsz,
-                num_heads=n_heads,
-                q_seq_len=q_seq_len,
-                k_seq_len=kv_seq_len,
-                head_dim=head_dim,
+                shape=(bsz, n_heads, q_seq_len, head_dim),
                 score_mod=score_mod,
                 dtype=dtype,
+                calculate_bwd_time=calculate_bwd,
             )
         )
 
     return all_configs
 
 
-def main(dynamic=False):
+def main(dynamic: bool, calculate_bwd: bool):
     seed = 123
     np.random.seed(seed)
     torch.manual_seed(seed)
     results = []
-    for config in tqdm(generate_experiment_configs()):
-        results.append(Experiment(config, run_single_experiment(config)))
+    for config in tqdm(generate_experiment_configs(calculate_bwd)):
+        results.append(
+            Experiment(config, run_single_experiment(config, dynamic=dynamic))
+        )
 
     print_results(results)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
+    # Set up the argument parser
+    parser = argparse.ArgumentParser(
+        description="Run sweep over sizes and score mods for flex attention"
+    )
     parser.add_argument(
         "--dynamic",
         action="store_true",
         help="Runs a dynamic shapes version of compiled flex attention.",
     )
+    parser.add_argument(
+        "--calculate-bwd", action="store_true", help="Calculate backward pass times"
+    )
+
+    # Parse arguments
     args = parser.parse_args()
-    main(args.dynamic)
+    main(args.dynamic, args.calculate_bwd)
 
```
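A quick usage sketch (hedged: the benchmark file's path is not shown in this excerpt, and `main` is the function defined in the diff above). The new `--calculate-bwd` flag composes with the existing `--dynamic` flag:

```python
# Equivalent of `python <path-to-benchmark> --dynamic --calculate-bwd`,
# driving the sweep programmatically via the `main` defined above:
main(dynamic=True, calculate_bwd=True)

# Forward-only sweep (prior behavior): bwd_times stays None and only the
# fwd_speedup columns are printed.
main(dynamic=False, calculate_bwd=False)
```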