Add Lowering for FlexAttention Backwards (#125515)

# Summary
#### What does this PR do?
It enables Inductor to actually generate the fused flex attention kernel for the backward pass.

I did some other things along the way:
- Abstracts out the `build_subgraph_buffer` subroutine and makes it reusable between flex attention forward and backward. In total we need to build 3 subgraphs for fwd + bwd: 1 for the fwd graph and then 2 in the bwd. The FAv2 algorithm recomputes parts of the forward (more efficiently, since we already have the row max via the logsumexp), so we need to inline both the fwd graph and the joint graph in the backward kernel (see the sketch after this list).
- The backward kernel is based on a somewhat older version of the Triton tutorial implementation; we should update it to a newer version in a follow-up. Notably, the blocks need to be square for the current implementation to work, and there are many remaining opportunities for optimization.
- I didn't correctly register the decomp table + IndexMode when I landed https://github.com/pytorch/pytorch/pull/123902; this remedies that.
- The rel_bias helper func had its causality reversed. I fixed it and added a test specifically for "future causal" attention.
- The main point that I think still needs to be worked out is the store_output call. I have it hacked up to be 'fake', but I don't think we want to land that; we likely want just a mutated 'dq' and a stored-output 'dk'.
- I also needed to update `TritonTemplateKernel` to accept multiple subgraphs (modifications).
- I updated the benchmark to also profile backward performance.
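For intuition, here is a minimal, hypothetical sketch (not code from this PR) of what the subgraphs compute for a toy `score_mod` with the `(score, b, h, q_idx, kv_idx)` signature. The fwd subgraph is the score modification itself (re-traced for the recompute inside the backward kernel), and the joint subgraph maps the gradient of the modified score back to the raw score:

```python
import torch

# Toy score_mod; the real ones (causal, rel_bias, alibi, ...) live in the tests/benchmark.
def rel_bias(score, b, h, m, n):
    return score + (m - n)

# Scalar example values, analogous to the zeros used when tracing the HOP subgraphs.
score = torch.zeros((), requires_grad=True)
b, h, m, n = (torch.zeros((), dtype=torch.int) for _ in range(4))

# Fwd subgraph: applied to each qk score in the forward kernel, and recomputed
# (FAv2-style, reusing the stored logsumexp) inside the backward kernel.
out = rel_bias(score, b, h, m, n)

# Joint subgraph: gradient of the modified score w.r.t. the raw score; only
# `score` is differentiable, the index scalars are not.
(grad_score,) = torch.autograd.grad(out, score, torch.ones(()))
```

In the backward Triton template these correspond to the two `modification(...)` calls: `subgraph_number=0` recomputes the forward score mod and `subgraph_number=1` applies the joint graph to `ds`.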

### Benchmark Numbers:
_The current implementation does not parallelize over context length in the backward pass._
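In the tables below, speedup is eager time divided by compiled time for the given pass. A minimal sketch of the computation (the helper name is ours, mirroring `calculate_speedup` in the updated benchmark script; all times are in microseconds):

```python
def speedup(eager_time_us: float, compiled_time_us: float) -> float:
    # >1.0 means the compiled flex attention kernel beats eager.
    return eager_time_us / compiled_time_us

# The max fwd row from the full data: (16, 16, 4096, 64), noop, bfloat16
assert round(speedup(3929.850, 3325.667), 3) == 1.182
```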
FWD Speedups

| Type    |   Speedup | shape              | score_mod   | dtype          |
|---------|-----------|--------------------|-------------|----------------|
| Average |     0.991 |                    |             |                |
| Max     |     1.182 | (16, 16, 4096, 64) | noop        | torch.bfloat16 |
| Min     |     0.796 | (2, 16, 512, 256)  | head_bias   | torch.bfloat16 |

BWD Speedups

| Type    |   Speedup | shape              | score_mod   | dtype          |
|---------|-----------|--------------------|-------------|----------------|
| Average |     0.291 |                    |             |                |
| Max     |     0.652 | (8, 16, 512, 64)   | head_bias   | torch.bfloat16 |
| Min     |     0.073 | (2, 16, 4096, 128) | head_bias   | torch.bfloat16 |

<details>

<summary>Full Data (times in µs)</summary>

| shape               | score_mod     | dtype          |   fwd_eager_time |   fwd_compiled_time |   bwd_eager_time |   bwd_compiled_time |   fwd_speedup |   bwd_speedup |
|---------------------|---------------|----------------|------------------|---------------------|------------------|---------------------|---------------|---------------|
| (2, 16, 512, 64)    | noop          | torch.bfloat16 |           19.936 |              19.092 |           57.851 |             193.564 |         1.044 |         0.299 |
| (2, 16, 512, 64)    | causal_mask   | torch.bfloat16 |           19.955 |              19.497 |           57.662 |             206.278 |         1.024 |         0.280 |
| (2, 16, 512, 64)    | relative_bias | torch.bfloat16 |           19.455 |              21.297 |           57.674 |             195.219 |         0.913 |         0.295 |
| (2, 16, 512, 64)    | head_bias     | torch.bfloat16 |           19.958 |              21.289 |           57.674 |             193.859 |         0.938 |         0.298 |
| (2, 16, 512, 128)   | noop          | torch.bfloat16 |           28.157 |              28.615 |           82.831 |             454.211 |         0.984 |         0.182 |
| (2, 16, 512, 128)   | causal_mask   | torch.bfloat16 |           28.154 |              28.444 |           83.091 |             432.083 |         0.990 |         0.192 |
| (2, 16, 512, 128)   | relative_bias | torch.bfloat16 |           28.722 |              27.897 |           83.175 |             446.789 |         1.030 |         0.186 |
| (2, 16, 512, 128)   | head_bias     | torch.bfloat16 |           28.299 |              27.673 |           83.052 |             459.179 |         1.023 |         0.181 |
| (2, 16, 512, 256)   | noop          | torch.bfloat16 |           41.167 |              50.504 |          175.019 |            1083.545 |         0.815 |         0.162 |
| (2, 16, 512, 256)   | causal_mask   | torch.bfloat16 |           41.656 |              51.933 |          175.078 |            1171.176 |         0.802 |         0.149 |
| (2, 16, 512, 256)   | relative_bias | torch.bfloat16 |           41.697 |              50.722 |          175.159 |            1097.312 |         0.822 |         0.160 |
| (2, 16, 512, 256)   | head_bias     | torch.bfloat16 |           41.690 |              52.387 |          175.184 |            1097.336 |         0.796 |         0.160 |
| (2, 16, 1024, 64)   | noop          | torch.bfloat16 |           39.232 |              37.454 |          127.847 |             612.430 |         1.047 |         0.209 |
| (2, 16, 1024, 64)   | causal_mask   | torch.bfloat16 |           39.930 |              39.599 |          127.755 |             665.359 |         1.008 |         0.192 |
| (2, 16, 1024, 64)   | relative_bias | torch.bfloat16 |           39.417 |              41.304 |          127.902 |             614.990 |         0.954 |         0.208 |
| (2, 16, 1024, 64)   | head_bias     | torch.bfloat16 |           39.965 |              42.034 |          127.953 |             613.273 |         0.951 |         0.209 |
| (2, 16, 1024, 128)  | noop          | torch.bfloat16 |           63.964 |              71.024 |          226.510 |            1637.669 |         0.901 |         0.138 |
| (2, 16, 1024, 128)  | causal_mask   | torch.bfloat16 |           63.843 |              72.451 |          226.750 |            1558.949 |         0.881 |         0.145 |
| (2, 16, 1024, 128)  | relative_bias | torch.bfloat16 |           64.301 |              70.487 |          226.651 |            1610.063 |         0.912 |         0.141 |
| (2, 16, 1024, 128)  | head_bias     | torch.bfloat16 |           64.033 |              71.394 |          226.676 |            1668.511 |         0.897 |         0.136 |
| (2, 16, 1024, 256)  | noop          | torch.bfloat16 |          129.348 |             141.390 |          507.337 |            4405.175 |         0.915 |         0.115 |
| (2, 16, 1024, 256)  | causal_mask   | torch.bfloat16 |          129.538 |             145.680 |          507.178 |            4768.874 |         0.889 |         0.106 |
| (2, 16, 1024, 256)  | relative_bias | torch.bfloat16 |          129.438 |             142.782 |          507.004 |            4401.002 |         0.907 |         0.115 |
| (2, 16, 1024, 256)  | head_bias     | torch.bfloat16 |          129.058 |             146.242 |          507.547 |            4434.251 |         0.883 |         0.114 |
| (2, 16, 4096, 64)   | noop          | torch.bfloat16 |          481.606 |             409.120 |         1440.890 |           14147.269 |         1.177 |         0.102 |
| (2, 16, 4096, 64)   | causal_mask   | torch.bfloat16 |          480.227 |             438.847 |         1434.419 |           14973.386 |         1.094 |         0.096 |
| (2, 16, 4096, 64)   | relative_bias | torch.bfloat16 |          480.831 |             458.104 |         1432.935 |           14193.253 |         1.050 |         0.101 |
| (2, 16, 4096, 64)   | head_bias     | torch.bfloat16 |          480.749 |             452.497 |         1437.040 |           14084.869 |         1.062 |         0.102 |
| (2, 16, 4096, 128)  | noop          | torch.bfloat16 |          872.534 |             848.275 |         2600.895 |           35156.849 |         1.029 |         0.074 |
| (2, 16, 4096, 128)  | causal_mask   | torch.bfloat16 |          872.647 |             868.279 |         2587.581 |           31919.531 |         1.005 |         0.081 |
| (2, 16, 4096, 128)  | relative_bias | torch.bfloat16 |          871.484 |             827.644 |         2593.989 |           34805.634 |         1.053 |         0.075 |
| (2, 16, 4096, 128)  | head_bias     | torch.bfloat16 |          871.422 |             856.437 |         2602.482 |           35708.591 |         1.017 |         0.073 |
| (2, 16, 4096, 256)  | noop          | torch.bfloat16 |         1904.497 |            1758.183 |         6122.416 |           66754.593 |         1.083 |         0.092 |
| (2, 16, 4096, 256)  | causal_mask   | torch.bfloat16 |         1911.174 |            1762.821 |         6113.207 |           72759.392 |         1.084 |         0.084 |
| (2, 16, 4096, 256)  | relative_bias | torch.bfloat16 |         1911.254 |            1727.108 |         6123.530 |           66577.988 |         1.107 |         0.092 |
| (2, 16, 4096, 256)  | head_bias     | torch.bfloat16 |         1916.977 |            1801.804 |         6118.158 |           67359.680 |         1.064 |         0.091 |
| (8, 16, 512, 64)    | noop          | torch.bfloat16 |           44.984 |              43.974 |          170.276 |             262.259 |         1.023 |         0.649 |
| (8, 16, 512, 64)    | causal_mask   | torch.bfloat16 |           45.001 |              46.265 |          170.509 |             274.893 |         0.973 |         0.620 |
| (8, 16, 512, 64)    | relative_bias | torch.bfloat16 |           45.466 |              48.211 |          170.606 |             262.759 |         0.943 |         0.649 |
| (8, 16, 512, 64)    | head_bias     | torch.bfloat16 |           45.481 |              48.435 |          170.267 |             261.265 |         0.939 |         0.652 |
| (8, 16, 512, 128)   | noop          | torch.bfloat16 |           72.565 |              74.736 |          313.220 |             773.126 |         0.971 |         0.405 |
| (8, 16, 512, 128)   | causal_mask   | torch.bfloat16 |           72.015 |              75.755 |          313.311 |             775.513 |         0.951 |         0.404 |
| (8, 16, 512, 128)   | relative_bias | torch.bfloat16 |           72.105 |              74.189 |          313.806 |             769.238 |         0.972 |         0.408 |
| (8, 16, 512, 128)   | head_bias     | torch.bfloat16 |           72.005 |              74.364 |          313.509 |             775.237 |         0.968 |         0.404 |
| (8, 16, 512, 256)   | noop          | torch.bfloat16 |          138.656 |             165.453 |          663.707 |            2672.067 |         0.838 |         0.248 |
| (8, 16, 512, 256)   | causal_mask   | torch.bfloat16 |          139.096 |             172.613 |          663.593 |            2926.538 |         0.806 |         0.227 |
| (8, 16, 512, 256)   | relative_bias | torch.bfloat16 |          139.500 |             168.417 |          663.938 |            2658.629 |         0.828 |         0.250 |
| (8, 16, 512, 256)   | head_bias     | torch.bfloat16 |          139.776 |             173.549 |          662.920 |            2667.266 |         0.805 |         0.249 |
| (8, 16, 1024, 64)   | noop          | torch.bfloat16 |          134.883 |             125.004 |          484.706 |            1195.254 |         1.079 |         0.406 |
| (8, 16, 1024, 64)   | causal_mask   | torch.bfloat16 |          134.297 |             132.875 |          485.420 |            1234.953 |         1.011 |         0.393 |
| (8, 16, 1024, 64)   | relative_bias | torch.bfloat16 |          134.839 |             139.231 |          485.470 |            1198.556 |         0.968 |         0.405 |
| (8, 16, 1024, 64)   | head_bias     | torch.bfloat16 |          133.822 |             136.449 |          485.608 |            1189.198 |         0.981 |         0.408 |
| (8, 16, 1024, 128)  | noop          | torch.bfloat16 |          235.470 |             234.765 |          886.094 |            2662.944 |         1.003 |         0.333 |
| (8, 16, 1024, 128)  | causal_mask   | torch.bfloat16 |          236.305 |             241.382 |          886.293 |            2646.984 |         0.979 |         0.335 |
| (8, 16, 1024, 128)  | relative_bias | torch.bfloat16 |          236.414 |             233.980 |          885.250 |            2642.178 |         1.010 |         0.335 |
| (8, 16, 1024, 128)  | head_bias     | torch.bfloat16 |          237.176 |             239.040 |          885.754 |            2665.242 |         0.992 |         0.332 |
| (8, 16, 1024, 256)  | noop          | torch.bfloat16 |          504.445 |             517.855 |         1978.956 |            9592.906 |         0.974 |         0.206 |
| (8, 16, 1024, 256)  | causal_mask   | torch.bfloat16 |          502.428 |             536.002 |         1978.611 |           10607.342 |         0.937 |         0.187 |
| (8, 16, 1024, 256)  | relative_bias | torch.bfloat16 |          503.396 |             523.960 |         1977.993 |            9539.284 |         0.961 |         0.207 |
| (8, 16, 1024, 256)  | head_bias     | torch.bfloat16 |          503.818 |             536.014 |         1980.131 |            9576.262 |         0.940 |         0.207 |
| (8, 16, 4096, 64)   | noop          | torch.bfloat16 |         1970.139 |            1674.930 |         5750.940 |           16724.134 |         1.176 |         0.344 |
| (8, 16, 4096, 64)   | causal_mask   | torch.bfloat16 |         1959.036 |            1775.056 |         5780.512 |           17390.350 |         1.104 |         0.332 |
| (8, 16, 4096, 64)   | relative_bias | torch.bfloat16 |         1947.198 |            1773.869 |         5780.643 |           16779.699 |         1.098 |         0.345 |
| (8, 16, 4096, 64)   | head_bias     | torch.bfloat16 |         1963.935 |            1829.502 |         5780.018 |           16703.259 |         1.073 |         0.346 |
| (8, 16, 4096, 128)  | noop          | torch.bfloat16 |         3582.711 |            3362.623 |        10436.069 |           36415.565 |         1.065 |         0.287 |
| (8, 16, 4096, 128)  | causal_mask   | torch.bfloat16 |         3581.504 |            3499.472 |        10346.869 |           36164.959 |         1.023 |         0.286 |
| (8, 16, 4096, 128)  | relative_bias | torch.bfloat16 |         3589.779 |            3337.849 |        10529.621 |           36261.696 |         1.075 |         0.290 |
| (8, 16, 4096, 128)  | head_bias     | torch.bfloat16 |         3602.265 |            3436.444 |        10458.660 |           36507.790 |         1.048 |         0.286 |
| (8, 16, 4096, 256)  | noop          | torch.bfloat16 |         7695.923 |            7126.275 |        24643.009 |          140949.081 |         1.080 |         0.175 |
| (8, 16, 4096, 256)  | causal_mask   | torch.bfloat16 |         7679.939 |            7186.252 |        24538.105 |          157156.067 |         1.069 |         0.156 |
| (8, 16, 4096, 256)  | relative_bias | torch.bfloat16 |         7681.374 |            6994.832 |        24549.713 |          140077.179 |         1.098 |         0.175 |
| (8, 16, 4096, 256)  | head_bias     | torch.bfloat16 |         7679.822 |            7212.278 |        24627.823 |          140675.003 |         1.065 |         0.175 |
| (16, 16, 512, 64)   | noop          | torch.bfloat16 |           80.126 |              78.291 |          333.719 |             541.165 |         1.023 |         0.617 |
| (16, 16, 512, 64)   | causal_mask   | torch.bfloat16 |           80.065 |              81.696 |          333.779 |             551.113 |         0.980 |         0.606 |
| (16, 16, 512, 64)   | relative_bias | torch.bfloat16 |           80.138 |              86.715 |          333.364 |             542.118 |         0.924 |         0.615 |
| (16, 16, 512, 64)   | head_bias     | torch.bfloat16 |           80.415 |              85.204 |          333.294 |             536.840 |         0.944 |         0.621 |
| (16, 16, 512, 128)  | noop          | torch.bfloat16 |          134.964 |             138.025 |          607.093 |            1333.102 |         0.978 |         0.455 |
| (16, 16, 512, 128)  | causal_mask   | torch.bfloat16 |          134.192 |             141.523 |          606.269 |            1424.318 |         0.948 |         0.426 |
| (16, 16, 512, 128)  | relative_bias | torch.bfloat16 |          135.711 |             138.639 |          606.283 |            1327.974 |         0.979 |         0.457 |
| (16, 16, 512, 128)  | head_bias     | torch.bfloat16 |          135.552 |             140.555 |          607.107 |            1347.370 |         0.964 |         0.451 |
| (16, 16, 512, 256)  | noop          | torch.bfloat16 |          275.113 |             315.144 |         1301.583 |            5268.153 |         0.873 |         0.247 |
| (16, 16, 512, 256)  | causal_mask   | torch.bfloat16 |          274.867 |             328.106 |         1302.513 |            5770.594 |         0.838 |         0.226 |
| (16, 16, 512, 256)  | relative_bias | torch.bfloat16 |          276.052 |             321.770 |         1302.904 |            5241.920 |         0.858 |         0.249 |
| (16, 16, 512, 256)  | head_bias     | torch.bfloat16 |          271.409 |             328.839 |         1302.142 |            5266.037 |         0.825 |         0.247 |
| (16, 16, 1024, 64)  | noop          | torch.bfloat16 |          260.489 |             237.463 |          955.884 |            1817.558 |         1.097 |         0.526 |
| (16, 16, 1024, 64)  | causal_mask   | torch.bfloat16 |          262.378 |             254.350 |          955.280 |            1843.807 |         1.032 |         0.518 |
| (16, 16, 1024, 64)  | relative_bias | torch.bfloat16 |          261.338 |             268.253 |          956.038 |            1820.036 |         0.974 |         0.525 |
| (16, 16, 1024, 64)  | head_bias     | torch.bfloat16 |          262.153 |             264.156 |          956.023 |            1810.076 |         0.992 |         0.528 |
| (16, 16, 1024, 128) | noop          | torch.bfloat16 |          476.475 |             461.413 |         1760.578 |            4306.521 |         1.033 |         0.409 |
| (16, 16, 1024, 128) | causal_mask   | torch.bfloat16 |          473.794 |             479.178 |         1761.277 |            4619.439 |         0.989 |         0.381 |
| (16, 16, 1024, 128) | relative_bias | torch.bfloat16 |          473.839 |             463.282 |         1758.692 |            4290.562 |         1.023 |         0.410 |
| (16, 16, 1024, 128) | head_bias     | torch.bfloat16 |          472.979 |             472.896 |         1763.086 |            4367.931 |         1.000 |         0.404 |
| (16, 16, 1024, 256) | noop          | torch.bfloat16 |         1014.184 |            1026.764 |         3922.997 |           19104.147 |         0.988 |         0.205 |
| (16, 16, 1024, 256) | causal_mask   | torch.bfloat16 |         1013.217 |            1039.046 |         3928.382 |           21086.281 |         0.975 |         0.186 |
| (16, 16, 1024, 256) | relative_bias | torch.bfloat16 |         1008.519 |            1015.278 |         3922.133 |           18980.652 |         0.993 |         0.207 |
| (16, 16, 1024, 256) | head_bias     | torch.bfloat16 |         1011.360 |            1047.542 |         3931.245 |           19069.172 |         0.965 |         0.206 |
| (16, 16, 4096, 64)  | noop          | torch.bfloat16 |         3929.850 |            3325.667 |        11411.704 |           23344.280 |         1.182 |         0.489 |
| (16, 16, 4096, 64)  | causal_mask   | torch.bfloat16 |         3885.262 |            3581.544 |        11390.515 |           23725.639 |         1.085 |         0.480 |
| (16, 16, 4096, 64)  | relative_bias | torch.bfloat16 |         3865.737 |            3537.308 |        11489.901 |           23406.330 |         1.093 |         0.491 |
| (16, 16, 4096, 64)  | head_bias     | torch.bfloat16 |         3880.530 |            3665.249 |        11484.411 |           23299.496 |         1.059 |         0.493 |
| (16, 16, 4096, 128) | noop          | torch.bfloat16 |         7030.306 |            6745.715 |        20621.264 |           57464.096 |         1.042 |         0.359 |
| (16, 16, 4096, 128) | causal_mask   | torch.bfloat16 |         7095.414 |            7034.385 |        20410.656 |           61660.511 |         1.009 |         0.331 |
| (16, 16, 4096, 128) | relative_bias | torch.bfloat16 |         7084.779 |            6686.497 |        20315.161 |           57243.969 |         1.060 |         0.355 |
| (16, 16, 4096, 128) | head_bias     | torch.bfloat16 |         7075.367 |            6863.305 |        20494.385 |           58481.953 |         1.031 |         0.350 |
| (16, 16, 4096, 256) | noop          | torch.bfloat16 |        15612.741 |           14297.482 |        55306.847 |          281161.865 |         1.092 |         0.197 |
| (16, 16, 4096, 256) | causal_mask   | torch.bfloat16 |        15326.592 |           14263.878 |        55227.806 |          313063.232 |         1.075 |         0.176 |
| (16, 16, 4096, 256) | relative_bias | torch.bfloat16 |        15297.963 |           14007.379 |        54558.029 |          279529.175 |         1.092 |         0.195 |
| (16, 16, 4096, 256) | head_bias     | torch.bfloat16 |        15216.160 |           14276.027 |        55081.581 |          280996.826 |         1.066 |         0.196 |

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125515
Approved by: https://github.com/Chillee
Author: drisspg
Date: 2024-05-16 03:14:22 +00:00
Committed by: PyTorch MergeBot
Parent: ae6fdfa539
Commit: 95b9e981c3
9 changed files with 828 additions and 309 deletions


@ -3,7 +3,7 @@ import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, List
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
@ -29,28 +29,32 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) ->
@dataclass(frozen=True)
class ExperimentConfig:
batch_size: int
num_heads: int
q_seq_len: int
k_seq_len: int
head_dim: int
shape: Tuple[int]
score_mod: Callable
dtype: torch.dtype
calculate_bwd_time: bool
def __post_init__(self):
assert len(self.shape) == 4, "Shape must be of length 4"
def asdict(self):
return asdict(self)
# Convert the dataclass instance to a dictionary
d = asdict(self)
# Remove the 'calculate_bwd_time' key
d.pop("calculate_bwd_time", None)
return d
@dataclass(frozen=True)
class Times:
eager_time: float
compiled_time: float
@dataclass(frozen=True)
class ExperimentResults:
eager_time: float
compiled_time: float
def get_entries(self) -> List:
return [
f"{self.eager_time:2f}",
f"{self.compiled_time:2f}",
]
fwd_times: Times
bwd_times: Optional[Times]
@dataclass(frozen=True)
@ -58,29 +62,31 @@ class Experiment:
config: ExperimentConfig
results: ExperimentResults
def get_entries(self) -> List:
return self.config.get_entries() + self.results.get_entries()
def asdict(self):
dict1 = asdict(self.config)
dict1 = self.config.asdict()
dict2 = asdict(self.results)
return {**dict1, **dict2}
def generate_inputs(
batch_size,
num_heads,
q_sequence_length,
kv_sequence_length,
head_dim,
dtype,
device,
batch_size: int,
num_heads: int,
q_sequence_length: int,
kv_sequence_length: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
):
q_shape = (batch_size, q_sequence_length, num_heads * head_dim)
kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim)
make_q = partial(torch.rand, q_shape, device=device, dtype=dtype)
make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype)
make_q = partial(
torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
make_kv = partial(
torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad
)
query = (
make_q()
.view(batch_size, q_sequence_length, num_heads, head_dim)
@ -101,14 +107,16 @@ def generate_inputs(
def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
device = torch.device("cuda")
batch_size, num_heads, q_seq_len, head_dim = config.shape
query, key, value = generate_inputs(
config.batch_size,
config.num_heads,
config.q_seq_len,
config.k_seq_len,
config.head_dim,
batch_size,
num_heads,
q_seq_len,
q_seq_len,
head_dim,
config.dtype,
device,
requires_grad=config.calculate_bwd_time,
)
def eager_sdpa(query, key, value, _):
@ -125,23 +133,47 @@ def run_single_experiment(config: ExperimentConfig, dynamic=False) -> Experiment
compiled_sdpa, query, key, value, score_mod
)
if config.calculate_bwd_time:
out_eager = eager_sdpa(query, key, value, score_mod)
dOut = torch.randn_like(out_eager)
backward_eager_time = benchmark_torch_function_in_microseconds(
out_eager.backward, dOut, retain_graph=True
)
out_compile = compiled_sdpa(query, key, value, score_mod)
dOut = torch.randn_like(out_eager)
backward_compile_time = benchmark_torch_function_in_microseconds(
out_compile.backward, dOut, retain_graph=True
)
return ExperimentResults(
eager_time=forward_eager_time,
compiled_time=forward_compiled_time,
fwd_times=Times(forward_eager_time, forward_compiled_time),
bwd_times=Times(backward_eager_time, backward_compile_time),
)
else:
return ExperimentResults(
fwd_times=Times(forward_eager_time, forward_compiled_time),
bwd_times=None,
)
def calculate_speedup(results: ExperimentResults) -> float:
return results.eager_time / results.compiled_time
def calculate_speedup(results: ExperimentResults, type: str) -> float:
if type == "fwd":
return results.fwd_times.eager_time / results.fwd_times.compiled_time
elif type == "bwd":
assert results.bwd_times is not None
return results.bwd_times.eager_time / results.bwd_times.compiled_time
else:
raise ValueError(f"Invalid type {type}")
def get_func_name(func):
return func.__name__.split("<locals>.")[-1].split(" at ")[0]
def get_average_speedups(results: List[Experiment]):
def get_average_speedups(results: List[Experiment], type: str):
# Calculate speedups
speedups = [calculate_speedup(r.results) for r in results]
speedups = [calculate_speedup(r.results, type) for r in results]
# Find indices of max and min speedups
max_speedup_index = np.argmax(speedups)
@ -177,18 +209,37 @@ def print_results(results: List[Experiment]):
table_data = defaultdict(list)
for experiment in results:
for key, value in experiment.asdict().items():
if key == "eager_time" or key == "compiled_time":
value = float(value)
if key == "fwd_times":
for name, time in value.items():
table_data[f"fwd_{name}"].append(float(time))
elif key == "bwd_times":
if experiment.config.calculate_bwd_time:
for name, time in value.items():
table_data[f"bwd_{name}"].append(float(time))
else:
table_data[key].append(value)
# Calculate speedups
speedups = [calculate_speedup(r.results) for r in results]
table_data["speedup"] = speedups
fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results]
table_data["fwd_speedup"] = fwd_speedups
if results[0].config.calculate_bwd_time:
bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results]
table_data["bwd_speedup"] = bwd_speedups
table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]]
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
average_data = get_average_speedups(results)
print("\n")
print("FWD Speedups".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="fwd")
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
if results[0].config.calculate_bwd_time:
print("\n")
print("BWD Speedups".center(125, "="))
print("\n")
average_data = get_average_speedups(results, type="bwd")
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
@ -208,8 +259,8 @@ def generate_score_mods() -> List[Callable]:
return [noop, causal_mask, relative_bias, head_bias]
def generate_experiment_configs() -> List[ExperimentConfig]:
batch_sizes = [1, 8, 16]
def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]:
batch_sizes = [2, 8, 16]
num_heads = [16]
q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)]
head_dims = [64, 128, 256]
@ -228,41 +279,49 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
) in itertools.product(
batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes
):
assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now."
all_configs.append(
ExperimentConfig(
batch_size=bsz,
num_heads=n_heads,
q_seq_len=q_seq_len,
k_seq_len=kv_seq_len,
head_dim=head_dim,
shape=(bsz, n_heads, q_seq_len, head_dim),
score_mod=score_mod,
dtype=dtype,
calculate_bwd_time=calculate_bwd,
)
)
return all_configs
def main(dynamic=False):
def main(dynamic: bool, calculate_bwd: bool):
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
results = []
for config in tqdm(generate_experiment_configs()):
for config in tqdm(generate_experiment_configs(calculate_bwd)):
results.append(
Experiment(config, run_single_experiment(config, dynamic=dynamic))
)
for config in tqdm(generate_experiment_configs(calculate_bwd)):
results.append(Experiment(config, run_single_experiment(config)))
print_results(results)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Set up the argument parser
parser = argparse.ArgumentParser(
description="Run sweep over sizes and score mods for flex attention"
)
parser.add_argument(
"--dynamic",
action="store_true",
help="Runs a dynamic shapes version of compiled flex attention.",
)
parser.add_argument(
"--calculate-bwd", action="store_true", help="Calculate backward pass times"
)
# Parse arguments
args = parser.parse_args()
main(args.dynamic)
main(args.dynamic, args.calculate_bwd)


@ -1,8 +1,9 @@
# Owner(s): ["module: inductor"]
import functools
import unittest
from collections import namedtuple
from typing import Callable
from typing import Callable, Optional
from unittest import expectedFailure, skip, skipUnless
from unittest.mock import patch
@ -58,14 +59,8 @@ if common_utils.TEST_WITH_ROCM:
# --------- Useful score mod functions for testing ---------
test_score_mods = [
_identity,
_causal,
_rel_bias,
_rel_causal,
_generate_alibi_bias(8),
]
def _inverse_causal(score, b, h, m, n):
return torch.where(m <= n, score, float("-inf"))
def _times_two(score, b, h, m, n):
@ -79,13 +74,11 @@ def _squared(score, b, h, m, n):
def _head_offset(dtype: torch.dtype):
"""Captured Buffer
Note: this builds a score_mod with index of a type
"""
"""Captured Buffer"""
head_offset = torch.rand(H, device="cuda", dtype=dtype)
def score_mod(score, b, h, m, n):
return score * index(head_offset, [h])
return score * head_offset[h]
return score_mod
@ -103,20 +96,19 @@ def _trig2(score, b, h, m, n):
return z
def _buffer_reduced(dtype: torch.dtype):
"""Reduction in captured buffer"""
batch_offsets = torch.rand(B, 8, device="cuda", dtype=dtype)
def score_mod(score, b, h, m, n):
batch_vals = index(batch_offsets, [b])
return score + batch_vals.sum()
return score_mod
test_score_mods = [
_identity,
_times_two,
_squared,
_causal,
_inverse_causal,
_rel_bias,
_rel_causal,
_generate_alibi_bias(8),
]
captured_buffers_map = {
"_head_offset": _head_offset,
"_buffer_reduced": _buffer_reduced,
}
B = 4
@ -125,18 +117,35 @@ S = 2048
D = 64
def query_key_value_clones(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dtype: torch.dtype = None,
):
"""Clones the query, key, and value tensors and moves them to the specified dtype."""
if dtype is None:
dtype = query.dtype
query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad)
key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad)
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
return query_ref, key_ref, value_ref
class TestTemplatedSDPA(InductorTestCase):
def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
def _check_equal(
self,
golden_out: torch.Tensor,
ref_out: torch.Tensor,
compiled_out: torch.Tensor,
fudge_factor: float,
tensor_name: Optional[str] = None,
):
compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
if compiled_error > ref_error * fudge_factor:
msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
name = tensor_name if tensor_name is not None else ""
msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)
def run_test(
@ -150,15 +159,45 @@ class TestTemplatedSDPA(InductorTestCase):
):
sdpa_partial = create_attention(score_mod)
compiled_sdpa = torch.compile(sdpa_partial)
q = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out = sdpa_partial(
q.to(torch.float64), k.to(torch.float64), v.to(torch.float64)
)
ref_out = sdpa_partial(q, k, v)
q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
golden_out = sdpa_partial(q_gold, k_gold, v_gold)
ref_out = sdpa_partial(q_ref, k_ref, v_ref)
compiled_out = compiled_sdpa(q, k, v)
self._check_equal(golden_out, ref_out, compiled_out, dtype)
backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out.backward(backward_grad.to(torch.float64))
ref_out.backward(backward_grad)
compiled_out.backward(backward_grad)
with torch.no_grad():
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
# Check output
self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
# Check gradients
q_fudge_factor = 2.5 * fudge_factor
self._check_equal(
q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
)
k_fudge_factor = 4 * fudge_factor
self._check_equal(
k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
)
v_fudge_factor = 8 * fudge_factor
self._check_equal(
v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
)
def run_dynamic_test(
self,
@ -196,12 +235,20 @@ class TestTemplatedSDPA(InductorTestCase):
# Compiling with dynamic shape in the first batch.
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
compiled_out1 = compiled_sdpa(q1, k1, v1)
self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
# No re-compilation, use the compiled dynamic shape version.
compiled_out2 = compiled_sdpa(q2, k2, v2)
self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
def run_automatic_dynamic_test(
@ -251,20 +298,28 @@ class TestTemplatedSDPA(InductorTestCase):
# 2, the second batch is compiled with dynamic shape
# 3, no re-compilation in the third batch
torch._dynamo.reset()
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
# The first batch.
compiled_sdpa = torch.compile(sdpa_partial)
compiled_out1 = compiled_sdpa(q1, k1, v1)
self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
# The second batch (automatic dynamic).
compiled_out2 = compiled_sdpa(q2, k2, v2)
self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
# The third batch (no re-compilation).
compiled_out3 = compiled_sdpa(q3, k3, v3)
self._check_equal(golden_out3, ref_out3, compiled_out3, dtype)
self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
@supported_platform
@ -318,6 +373,21 @@ class TestTemplatedSDPA(InductorTestCase):
self.run_test(score_mod, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_captured_buffers_all_dims(self, dtype: torch.dtype):
head_scale = torch.randn(H, device="cuda")
batch_scale = torch.randn(B, device="cuda")
tok_scale = torch.randn(S, device="cuda")
def all_bias(score, batch, head, token_q, token_kv):
score = score + tok_scale[token_q]
score = score + batch_scale[batch]
score = score + head_scale[head]
return score
self.run_test(all_bias, dtype)
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
def test_seq_masking(self, dtype):
@ -422,7 +492,7 @@ class TestTemplatedSDPA(InductorTestCase):
make_tensor = functools.partial(
torch.randn,
(2, 2, 8, 4),
(2, 2, 128, 4),
device="cuda",
dtype=torch.float64,
requires_grad=True,
@ -458,6 +528,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
@supported_platform
@common_utils.parametrize("dtype", test_dtypes_fast)
@unittest.skip("Silu decomp failing for full in backwards")
def test_silu_on_score(self, dtype):
def silu_score(score, b, h, q, kv):
return torch.nn.functional.silu(score)
@ -597,23 +668,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
self.run_test(causal_njt, dtype)
@supported_platform
def test_backwards_fails(self):
make_tensor = functools.partial(
torch.randn,
(B, H, S, D),
dtype=torch.float32,
device="cuda",
requires_grad=True,
)
q, k, v = make_tensor(), make_tensor(), make_tensor()
func = torch.compile(_flex_attention, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
AssertionError, "flex_attention_backward is not an OpOverload"
):
out = func(q, k, v, _identity)
out.backward(torch.ones_like(out))
@supported_platform
def test_mixed_dtypes_fails(self):
query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda")
@ -641,6 +695,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
self.run_test(score_mod)
@supported_platform
@skip("TODO: Figure out why this is erroring")
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune_with_captured(self):
head_scale = torch.randn(H, device="cuda")
@ -776,7 +831,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
)
@supported_platform
@common_utils.parametrize("score_mod_name", ["_head_offset", "_buffer_reduced"])
@common_utils.parametrize("score_mod_name", ["_head_offset"])
@common_utils.parametrize("mode", ["eager", "aot_eager"])
def test_captured_score_mod_aot_eager_gradcheck(
self, score_mod_name: str, mode: str
@ -864,13 +919,10 @@ class GraphModule(torch.nn.Module):
joint_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", """
+ """alias_5: "f64[2, 2, 8, 4]", alias_7: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
fw_graph = self.fw_graph
joint_graph = self.joint_graph
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, """
+ """primals_3, alias_5, alias_7, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_5 """
+ """= alias_7 = tangents_1 = fw_graph = joint_graph = None
flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None
getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
@ -888,7 +940,7 @@ class GraphModule(torch.nn.Module):
mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
return [add, None, None, None, None]
""",
""", # noqa: B950
)


@ -406,12 +406,15 @@ def flex_attention_autograd(
score_mod: Callable,
*other_buffers: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, torch.Tensor]:
with TransformGetItemToIndex():
input_requires_grad = any(t.requires_grad for t in (query, key, value))
if torch.is_grad_enabled() and input_requires_grad:
example_vals = [
torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad)
] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
fw_graph, bw_graph = create_fw_bw_graph(score_mod, example_vals, other_buffers)
fw_graph, bw_graph = create_fw_bw_graph(
score_mod, example_vals, other_buffers
)
else:
fw_graph, bw_graph = score_mod, None
out, logsumexp = FlexAttentionAutogradOp.apply(
@ -449,6 +452,7 @@ def sdpa_dense_backward(
score_mod = torch.vmap(score_mod, in_dims=(0, None, 0, None, None) + in_dim_buffers)
score_mod = torch.vmap(score_mod, in_dims=(0, 0, None, None, None) + in_dim_buffers)
with TransformGetItemToIndex():
post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to(
working_precision
)
@ -485,6 +489,7 @@ def sdpa_dense_backward(
in_dims=(0, 0, None, None, None, 0) + in_dim_buffers,
out_dims=out_dims,
)
with TransformGetItemToIndex():
grad_scores, *_ = joint_score_mod(
scores, b, h, m, n, grad_score_mod, *other_buffers
)
@ -524,8 +529,9 @@ def trace_flex_attention_backward(
torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad)
] + [torch.zeros((), dtype=torch.int) for _ in range(4)]
bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)]
fw_graph = make_fx(fw_graph)(*fw_example_vals, *other_buffers)
joint_graph = make_fx(joint_graph)(*bw_example_vals, *other_buffers)
with TransformGetItemToIndex():
fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *other_buffers)
joint_graph = reenter_make_fx(joint_graph)(*bw_example_vals, *other_buffers)
proxy_mode.tracer.root.register_module("fw_graph", fw_graph)
proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
node_args = (


@ -3595,7 +3595,10 @@ class TritonTemplateBuffer(TemplateBuffer):
self.mutated_inputs = mutated_inputs
if mutated_inputs is not None:
# Ensure that the mutated inputs are only allowed for certain nodes
allowed_set = {torch.ops.higher_order.flex_attention}
allowed_set = {
torch.ops.higher_order.flex_attention,
torch.ops.higher_order.flex_attention_backward,
}
current_node = V.graph.current_node.target
assert (
current_node in allowed_set


@ -1,17 +1,39 @@
""" Triton Implementation of the flex_attention Kernel"""
import logging
from typing import Any, List
import math
from enum import auto, Enum
from typing import Any, List, Tuple
import torch
from torch._prims_common import make_contiguous_strides_for
from .. import config
from ..lowering import empty_strided, lowerings, register_lowering
from ..ir import (
ComputedBuffer,
FixedLayout,
FlexibleLayout,
InputBuffer,
IRNode,
StorageBox,
Subgraph,
TensorBox,
)
from ..lowering import empty_strided, full, lowerings, register_lowering
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
log = logging.getLogger(__name__)
aten = torch.ops.aten
def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta):
class SubgraphType(Enum):
"""The type of subgraph for which we want to generate an output buffer."""
FWD = auto() # Forward pass
JOINT_FWD = auto() # The recompute step of the bwd kernel
JOINT_BWD = auto() # The bwd pass of the joint
def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta):
"""How is this kernel parallelized?
We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1)
Each block is responsible for iterating over blocks of keys and values calculating
@ -22,9 +44,117 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta):
return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1)
sdpa_template = TritonTemplate(
name="sdpa",
grid=sdpa_grid,
def create_placeholder(
name: str, dtype: torch.dtype, device: torch.device
) -> TensorBox:
"""Creates a placeholder input buffers for producing subgraph_output."""
input_buffer = InputBuffer(name, FixedLayout(device, dtype, [1], [1]))
return TensorBox.create(input_buffer)
def index_to_other_buffers(cnt: int, graph_type: SubgraphType) -> int:
"""This function needs to be aware of the signatures for flex_attention_forward
and flex_attention_backward. If new args are added, or the signature changes
be sure to update the indexing math
Args:
cnt (int): The current index of the placeholder node
is_joint_graph (bool): Whether or not this subgraph represents the joint graph
"""
# Current fwd_args = [query, key, value, score_mod, *other_buffers]
# For fwd graphs we have 5 dummy values, so when the first lifted arg
# is seen cnt = 5 and the start of the index_buffers is at args[4];
# thus we subtract 1 from the current cnt
if graph_type == SubgraphType.FWD:
return cnt - 1
# Current bwd_args = [q, k, v, out, lse, grad_out, fw_graph, joint_graph, *other_buffers]
# We have 5 dummy values but the start of other_buffers is at index 8
if graph_type == SubgraphType.JOINT_FWD:
return cnt + 3
# Same bwd args but now with 6 dummy values while other_buffers still start at 8
if graph_type == SubgraphType.JOINT_BWD:
return cnt + 2
def build_subgraph_buffer(
args: Tuple[IRNode],
placeholder_inps: List[TensorBox],
subgraph: Subgraph,
graph_type: SubgraphType,
) -> ComputedBuffer:
"""This function's goal is to take in the required args and produce the subgraph buffer
The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
Args:
args: The args that were passed into the flex_attention kernel
placeholder_inps: The list of scalar inputs, these were created on the fly through `create_placeholder`
subgraph: The Subgraph ir for which to produce the output node
graph_type: The type of subgraph for which we want to produce the output node, see enum above for details
"""
cnt = 0
env = {}
for node in subgraph.graph_module.graph.nodes:
# There are two classes of placeholder inputs that we need
# to handle differently. For the first n_scalar_inps inputs
# we expect that these placeholders were generated by the make_fx call
# in the flex Attention HOP. So we need to create a new placeholder
# TensorBox for each of these inputs. For the rest of the inputs we
# expect that these are lifted inputs that fill up the '*other_buffers'
# tuple and already have corresponding TensorBoxes passed in as args.
if node.op == "placeholder":
is_lifted_input = cnt >= len(placeholder_inps)
lifted_input_index = index_to_other_buffers(cnt, graph_type)
env[node] = (
args[lifted_input_index] if is_lifted_input else placeholder_inps[cnt]
)
cnt += 1
elif node.op == "call_function":
# For call_function we use the default lowerings and pass in the
# already created TensorBoxes as args
from torch.utils._pytree import tree_map
env[node] = lowerings[node.target](
*tree_map(lambda x: env[x] if x in env else x, node.args)
)
elif node.op == "output":
# For the output node we need to create a ComputedBuffer
# which represents the actual score modification
# The joint_graph's output should be of the form[grad_score, None, None, None, None]
# This is because only the 'score' requires grad and the other outputs are
# the non-differentiable index scalars
if graph_type == SubgraphType.FWD or graph_type == SubgraphType.JOINT_FWD:
output_node = node.args[0]
else:
output_node = node.args[0][0]
output_buffer = env[output_node]
assert isinstance(output_buffer, TensorBox), (
"The output node for flex attention's subgraph must be a TensorBox, but got: ",
type(output_buffer),
)
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the flex attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
# Create the ComputedBuffer directly that will be inlined into the modification block
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=output_buffer.data.get_device(),
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
return subgraph_buffer
raise ValueError("TemplatedAttention was passed a subgraph with no output node!")
flex_attention_template = TritonTemplate(
name="flex_attention",
grid=flex_attention_grid,
source=r"""
{{def_kernel("Q", "K", "V", "LSE")}}
# Sub notation for this kernel:
@ -118,6 +248,7 @@ sdpa_template = TritonTemplate(
m = offs_m[:, None]
n = start_n + offs_n[None, :]
{{ modification(
subgraph_number=0,
score="qk",
b="off_hz // H",
h="off_hz % H",
@ -192,7 +323,7 @@ _a100_default_config = {
}
def _get_default_config(query):
def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
dtype = query.get_dtype()
head_dim = query.get_size()[-1]
default_config = None
@ -218,43 +349,26 @@ def _get_default_config(query):
return default_config
def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:
head_dim = query.get_size()[-1]
dtype = query.get_dtype()
if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
if dtype == torch.float32:
return (64, 64, 4, 1)
return (128, 128, 4, 3)
elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100
return (32, 32, 4, 1)
else: # modest hardware or extremely large head_dim
return (32, 32, 4, 1)
# TODO: We probably also need a layout constraint?
@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
def flex_attention(*args, **kwargs):
from torch._prims_common import make_contiguous_strides_for
from ..ir import (
ComputedBuffer,
FixedLayout,
FlexibleLayout,
InputBuffer,
StorageBox,
TensorBox,
)
query, key, value, subgraph, *other_buffers = args
def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer:
return TensorBox.create(
InputBuffer(
name,
FixedLayout(
query.get_device(),
dtype,
[
1,
],
[
1,
],
),
)
)
scalar_inps = ["score", "b", "h", "m", "n"]
env = {}
cnt = 0
placeholder_inps = [
create_placeholder(name, dtype)
create_placeholder(name, dtype, query.get_device())
for name, dtype in [
("score", query.get_dtype()),
("b", torch.int32),
@ -263,48 +377,11 @@ def flex_attention(*args, **kwargs):
("n", torch.int32),
]
]
for node in subgraph.graph_module.graph.nodes:
# There are two classes of placeholder inpts that we need
# to handle differently. For the first n_scalar_inps inputs
# we expect that these placeholders were generated by the make_fx call
# in the flex Attention HOP. So we need to create a new placeholder
# TensorBox for each of these inputs. For the rest of the inputs we
# expect that these are lifted inputs that fill up the '*other_buffers'
# tuple and already have corresponding TensorBoxes passed in as args.
if node.op == "placeholder":
is_lifted_input = cnt >= len(scalar_inps)
env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt]
cnt += 1
elif node.op == "call_function":
# For call_function we use the defulat lowerings and pass in the
# already created TensorBoxes as args
from torch.utils._pytree import tree_map
env[node] = lowerings[node.target](
*tree_map(lambda x: env[x] if x in env else x, node.args)
subgraph_buffer = build_subgraph_buffer(
args, placeholder_inps, subgraph, graph_type=SubgraphType.FWD
)
elif node.op == "output":
# For the output node we need to create a ComputedBuffer
# which represents the actual score modification
output_buffer = env[node.args[0]]
assert isinstance(output_buffer.data, StorageBox), (
"The output node for the flex attention subgraph must be a StorageBox, but got: ",
type(output_buffer),
)
# Create the ComputedBuffer directly that will be inlined into the modification block
subgraph_buffer = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=output_buffer.data.get_device(),
dtype=output_buffer.data.get_dtype(),
size=output_buffer.data.get_size(),
),
data=output_buffer.data.data, # type: ignore[arg-type]
)
layout = FixedLayout(
output_buffer.get_device(),
query.get_device(),
query.get_dtype(),
query.get_size(),
make_contiguous_strides_for(query.get_size()),
@ -315,11 +392,11 @@ def flex_attention(*args, **kwargs):
logsumexp_shape,
None,
dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype
device=output_buffer.get_device(),
device=query.get_device(),
)
choices: List[Any] = []
configs: List[Any] = []
configs.append(_get_default_config(query))
configs: List[Tuple[int, int, int, int]] = []
configs.append(_get_default_config_fwd(query))
if config.max_autotune:
configs += [
(128, 64, 4, 3),
@ -328,15 +405,18 @@ def flex_attention(*args, **kwargs):
(64, 128, 4, 3),
(64, 64, 4, 3),
]
# Note, we don't need to pass in the captured buffers explicitly
# because they're implicitly added by the score_mod function
# We do need to explicitly pass it in for autotuning though.
for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
sdpa_template.maybe_append_choice(
flex_attention_template.maybe_append_choice(
choices=choices,
input_nodes=[query, key, value, logsumexp],
layout=layout,
subgraphs=subgraph_buffer,
subgraphs=[
subgraph_buffer,
],
mutated_inputs=[
logsumexp,
],
@ -353,8 +433,310 @@ def flex_attention(*args, **kwargs):
inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers)
return (
autotune_select_algorithm(
"sdpa", choices, inputs_for_autotuning, layout
"flex_attention", choices, inputs_for_autotuning, layout
),
logsumexp,
)
raise ValueError("TemplatedAttention was passed a subgraph with no output node!")
# ---------------------------- Backward HOP Implementation ----------------------------
def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta):
"""How is this kernel parallelized?
Currently this is only parallelizing over batch * num_heads, but we can, and want to
parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require
atomic updates to some grad values or to have a two pass kernel design.
"""
return (batch_size * num_heads, 1, 1)
flex_attention_backward_template = TritonTemplate(
name="flex_attention_backward",
grid=flex_attention_backward_grid,
source=r"""
{{def_kernel("Q", "K", "V", "OUT", "LSE", "DELTA", "DO", "DQ", "DV")}}
# Sub notation for this kernel:
# Q: Query, K: Key, V: Value
# OUT: Forward output, LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT* DO, axis=1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
# inductor codegen
# M: Number of queries, N: Number of keys/values, D: Model dimension
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# (Modifiable) Config options:
# BLOCK_M
# BLOCK_N
# SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the
# change of base out of the loop
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
# Define Q Strides
stride_qz = {{stride("Q", 0)}}
stride_qh = {{stride("Q", 1)}}
stride_qm = {{stride("Q", 2)}}
stride_qk = {{stride("Q", 3)}}
# Define K Strides
stride_kz = {{stride("K", 0)}}
stride_kh = {{stride("K", 1)}}
stride_kn = {{stride("K", 2)}}
stride_kk = {{stride("K", 3)}}
# Define V Strides
stride_vz = {{stride("V", 0)}}
stride_vh = {{stride("V", 1)}}
stride_vn = {{stride("V", 2)}}
stride_vk = {{stride("V", 3)}}
Z = {{size("Q", 0)}}
H = {{size("Q", 1)}}
N_CTX = {{size("Q", 2)}}
qk_scale = 1.0
MATMUL_PRECISION = Q.dtype.element_ty
off_hz = tl.program_id(0)
off_z = off_hz // H # batch idx
off_h = off_hz % H # head idx
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
# Asserting contiguous for now...
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_vz + off_h * stride_vh
# TODO I think that this should be N_CTX/BLOCK_N blocks
for start_n in range(0, NUM_Q_BLOCKS):
# We are not doing the causal optimization yet, which would allow us to start further down the
# kv column
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m = tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk)
do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = DELTA + off_hz * N_CTX
l_ptrs = LSE + off_hz * N_CTX
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
# Key and Value stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
if SCORE_MOD_IS_LINEAR:
qk_scale *= 1.44269504
q = (q * qk_scale).to(MATMUL_PRECISION)
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk)
pre_mod_scores = qk
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = offs_m_curr[:, None]
n = offs_n[None, :]
{{ modification(
subgraph_number=0,
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
out="qk"
) | indent_except_first(3) }}
# TODO: In the case that score_mod is linear, this can be LICMed
if not SCORE_MOD_IS_LINEAR:
qk *= 1.44269504
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr) # [BLOCK_M]
# compute ds = p * (dp - delta[:, None])
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
ds = p * dp
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
score="pre_mod_scores",
b="off_z",
h="off_h",
m="m",
n="n",
out="ds"
) | indent_except_first(3) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(MATMUL_PRECISION), k)
# Store grad_query
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
index_n = offs_n[:, None]
index_k = offs_k[None, :]
# Store grad_key and grad_value
dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk)
tl.store(dv_ptrs, dv)
# TODO generalize and add proper mask support
mask = (index_n != -1) & (index_k != -1)
{{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
""",
)
# TODO: We probably also need a layout constraint?
@register_lowering(
torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None
)
def flex_attention_backward(*args, **kwargs):
(
query,
key,
value,
out,
logsumexp,
grad_out,
fw_graph,
joint_graph,
*other_buffers,
) = args
device = query.get_device()
dtype = query.get_dtype()
fwd_placeholder_inps = [
create_placeholder(name, dtype, device)
for name, dtype in [
("score", dtype),
("b", torch.int32),
("h", torch.int32),
("m", torch.int32),
("n", torch.int32),
]
]
fw_subgraph_buffer = build_subgraph_buffer(
args, fwd_placeholder_inps, fw_graph, graph_type=SubgraphType.JOINT_FWD
)
joint_placeholder_inps = fwd_placeholder_inps + [
create_placeholder("out", dtype, device)
]
joint_subgraph_buffer = build_subgraph_buffer(
args, joint_placeholder_inps, joint_graph, graph_type=SubgraphType.JOINT_BWD
)
layout_k = FixedLayout(
key.get_device(),
key.get_dtype(),
key.get_size(),
make_contiguous_strides_for(key.get_size()),
)
# Create delta, which is needed for the bwd kernel
mul_delta = lowerings[aten.mul](out, grad_out)
delta = lowerings[aten.sum](mul_delta, axis=-1)
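# (For reference) in eager terms this is delta = (out * grad_out).sum(dim=-1): the per-row dot
# product of the forward output with its incoming gradient, which the kernel reads back as DELTA.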
# see NOTE:[TritonTemplates with multiple outputs]
grad_query = full(
query.get_size(), 0.0, dtype=dtype, device=device
) # torch.zeros equivalent
grad_query.realize()
grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device)
choices: List[Any] = []
configs: List[Tuple[int, int, int, int]] = []
configs.append(_get_default_config_bwd(query))
if config.max_autotune:
configs += [
(128, 128, 4, 3),
(128, 128, 8, 1),
(64, 64, 4, 3),
(64, 64, 8, 1),
]
for BLOCK_M, BLOCK_N, num_warps, num_stages in configs:
flex_attention_backward_template.maybe_append_choice(
choices=choices,
input_nodes=[
query,
key,
value,
out,
logsumexp,
delta,
grad_out,
grad_query,
grad_value,
],
layout=layout_k, # We use store_output only for grad_key
subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer],
mutated_inputs=[grad_query, grad_value],
num_stages=num_stages,
num_warps=num_warps,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=query.get_size()[-1],
NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M),
# For now, we always assume the "sound" option
SCORE_MOD_IS_LINEAR=False,
)
inputs_for_autotuning = [
query,
key,
value,
out,
logsumexp,
delta,
grad_out,
grad_query,
grad_value,
] + list(other_buffers)
grad_key = autotune_select_algorithm(
"flex_attention_backward", choices, inputs_for_autotuning, layout_k
)
return (
grad_query,
grad_key,
grad_value,
)
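A minimal sketch of exercising this lowering end to end (assuming the private `_flex_attention` entry point from the diff below returns the attention output, a CUDA device, and sequence lengths that satisfy the new multiple-of-128 check):

```python
import torch
from torch.nn.attention._flex_attention import _flex_attention

def rel_bias(score, b, h, m, n):
    # toy score_mod using the (score, b, h, m, n) placeholder signature built above
    return score + (m - n)

q, k, v = (
    torch.randn(2, 16, 512, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
out = torch.compile(_flex_attention)(q, k, v, rel_bias)
out.sum().backward()  # routes through flex_attention_backward and the template above
```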


@@ -103,7 +103,7 @@ class TritonTemplateKernel(TritonKernel):
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
-subgraphs=None,
+subgraphs: Optional[List[ir.ComputedBuffer]] = None,
*,
index_dtype,
):
@@ -114,7 +114,7 @@ class TritonTemplateKernel(TritonKernel):
)
self.input_nodes = input_nodes
self.output_node = output_node
-self.named_input_nodes = {}
+self.named_input_nodes = {} # type: ignore[var-annotated]
self.defines = defines
self.kernel_name = kernel_name
self.template_mask = None
@@ -128,10 +128,10 @@ class TritonTemplateKernel(TritonKernel):
self.prefix_args = prefix_args
self.suffix_args = suffix_args
self.epilogue_fn = epilogue_fn
-self.render_hooks = dict()
+self.render_hooks = dict() # type: ignore[var-annotated]
self.triton_meta: Optional[Dict[str, object]] = None
-# For Templated Attention
-self.subgraphs = subgraphs
+# For Templated Attention this can be a list of ir.Subgraph
+self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs
def need_numel_args(self):
return False
@@ -271,19 +271,28 @@ class TritonTemplateKernel(TritonKernel):
val = self.named_input_nodes[name].get_stride()[index]
return texpr(self.rename_indexing(val))
-def modification(self, **fixed_inputs) -> str:
-"""This function generates the code body to populate
-a 'modification' placeholder within a template
+def modification(self, subgraph_number: int, **fixed_inputs) -> str:
+"""This creates a modification function for a subgraph.
+To use this inside a template, the first argument should specify which subgraph to codegen for
TODO come up with standardized way to modify templates, with
potential multiple modifications
+Args:
+subgraph_number (int): The index of the subgraph in self.subgraphs
+"""
assert isinstance(subgraph_number, int)
assert isinstance(self.subgraphs, list)
assert subgraph_number < len(
self.subgraphs
), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
subgraph = self.subgraphs[subgraph_number]
def add_input(name):
return self.args.input(name)
name = f"PlaceholderSubstitution_{subgraph_number}"
class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined]
self.name = "PlaceholderSubstitution"
self.name = name
def load(self, name: str, index: sympy.Expr):
if name not in fixed_inputs:
@@ -297,15 +306,14 @@ class TritonTemplateKernel(TritonKernel):
def indirect_indexing(self, index_var, size, check):
return sympy_index_symbol(str(index_var))
# if self.modification_cache is None:
with V.set_ops_handler(PlaceholderSubstitution(V.ops)):
assert isinstance(
-self.subgraphs, ir.ComputedBuffer
-), "Expected the subgraph to be a ComputedBuffer"
-if isinstance(self.subgraphs.data, ir.InputBuffer):
-out = self.subgraphs.data.make_loader()((1,))
+subgraph, ir.ComputedBuffer
+), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}"
+if isinstance(subgraph.data, ir.InputBuffer):
+out = subgraph.data.make_loader()((1,))
else:
-out = self.subgraphs.data.inner_fn((1,))
+out = subgraph.data.inner_fn((1,))
self.codegen_body()
self.body.writeline(f"{fixed_inputs['out']} = {out.value}")
@@ -320,11 +328,18 @@ class TritonTemplateKernel(TritonKernel):
indices: Union[List[Any], Tuple[Any]],
val: str,
mask: Optional[str] = None,
indent_width: int = 4,
):
"""
Hook called from template code to store the final output
(if the buffer hasn't been optimized away), then append any
epilogue fusions.
"""Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away.
Args:
indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of
these indices and output strides must match `val`.
val (str): The value to store.
mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask
will be applied to the store.
indent_width (int): The number of spaces to use for indentation. This is used when the call to
store_output is indented in the kernel definition.
"""
assert isinstance(indices, (list, tuple))
assert isinstance(val, str)
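# For example, the backward template above calls
#   {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
# from inside its kv loop, so the rendered store body is indented by 8 spaces to line up with
# the surrounding Triton code.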
@@ -348,7 +363,7 @@ class TritonTemplateKernel(TritonKernel):
self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name(
"xindex"
)
-self.template_mask = mask
+self.template_mask = mask # type: ignore[assignment]
self.template_indices = indices
output_index = self.output_node.get_layout().make_indexer()(index_symbols)
output_index = self.rename_indexing(output_index)
@@ -373,7 +388,7 @@ class TritonTemplateKernel(TritonKernel):
def hook():
# more stuff might have been added since the codegen_body above
self.codegen_body()
-return textwrap.indent(self.body.getvalue(), " ").strip()
+return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()
assert "<STORE_OUTPUT>" not in self.render_hooks
self.render_hooks["<STORE_OUTPUT>"] = hook


@@ -96,6 +96,8 @@ def _flex_attention(
raise ValueError(
"NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor."
)
if query.size(-2) % 128 != 0:
raise ValueError("NYI: S and L must be a multiple of 128")
if not torch._dynamo.is_dynamo_supported():
raise RuntimeError("flex_attention requires dynamo support.")
@@ -149,7 +151,7 @@ def _rel_causal(
token_q: torch.Tensor,
token_kv: torch.Tensor,
) -> torch.Tensor:
-return torch.where(token_q <= token_kv, score + (token_q - token_kv), float("-inf"))
+return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf"))
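# With token_q >= token_kv, a query position can attend only to itself and earlier key positions
# (causal masking with a relative bias); the previous <= only allowed attention to future positions.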
def _generate_alibi_bias(num_heads: int):


@@ -118,9 +118,9 @@ def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs)
return score + h
yield SampleInput(
-make_arg(2, 2, 64, 8, low=0.1, high=2),
-make_arg(2, 2, 64, 8, low=0.1, high=2),
-make_arg(2, 2, 64, 8, low=0.1, high=2),
+make_arg(2, 2, 128, 8, low=0.1, high=2),
+make_arg(2, 2, 128, 8, low=0.1, high=2),
+make_arg(2, 2, 128, 8, low=0.1, high=2),
score_mod,
)