ce29d0d796
[ATen] Vectorize 8 elements on 16-bit data types for sum/mean (#165055)
...
Benchmarks cover a full reduction and a reduction along the contiguous dimension; vectorized loads do not occur along the non-contiguous dimension. Benchmarking was done for FP16/BF16: ~6% improvement on average across shapes, up to ~24% for a single reduction along the contiguous dimension and ~46% for a full reduce:
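For context, a minimal timing sketch of this kind of measurement (an illustration, not the script used to produce the numbers below; it assumes a CUDA device and uses CUDA events for device-side timing):
```python
# Hedged sketch: times sum() full reductions and contiguous-dim reductions
# for 16-bit dtypes. Not the PR's benchmark script.
import torch

def bench_ms(fn, iters=100):
    for _ in range(10):          # warm-up
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()     # wait so elapsed_time() is valid
    return start.elapsed_time(end) / iters  # ms per call

for dtype in (torch.bfloat16, torch.float16):
    for shape in [(256, 256), (1024, 1024), (8192, 8192)]:
        x = torch.randn(shape, device="cuda", dtype=dtype)
        full = bench_ms(lambda: x.sum())          # full reduce
        contig = bench_ms(lambda: x.sum(dim=-1))  # contiguous (last) dim
        print(f"{dtype}, {shape}: full={full:.6f} ms, contig={contig:.6f} ms")
```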
**BF16**
```
Tensor Shape          Operation   Full reduce before (ms)   Contiguous dim before (ms)   Full reduce after (ms)   Contiguous dim after (ms)   Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256) mean 0.022686 0.008263 0.015498 0.008117 +46.38% +1.80%
(256, 256) sum 0.022769 0.008269 0.015628 0.008185 +45.69% +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512) mean 0.014116 0.009545 0.012892 0.008839 +9.49% +7.99%
(512, 512) sum 0.014110 0.009892 0.012891 0.008878 +9.46% +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024) mean 0.014727 0.012642 0.014061 0.010519 +4.74% +20.18%
(1024, 1024) sum 0.014376 0.012636 0.014069 0.010595 +2.18% +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048) mean 0.018663 0.018294 0.018171 0.014678 +2.71% +24.64%
(2048, 2048) sum 0.018638 0.017931 0.018142 0.014713 +2.73% +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096) mean 0.034216 0.036953 0.033520 0.030585 +2.08% +20.82%
(4096, 4096) sum 0.034196 0.036942 0.033518 0.030676 +2.02% +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192) mean 0.087763 0.095201 0.085439 0.084960 +2.72% +12.05%
(8192, 8192) sum 0.088079 0.095592 0.085353 0.084632 +3.19% +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384) mean 0.148174 0.149705 0.146274 0.138865 +1.30% +7.81%
(8192, 16384) sum 0.147820 0.149371 0.146419 0.138752 +0.96% +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768) mean 0.266144 0.260807 0.265953 0.253330 +0.07% +2.95%
(8192, 32768) sum 0.266572 0.261163 0.265729 0.253294 +0.32% +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536) mean 0.502034 0.486312 0.498417 0.481246 +0.73% +1.05%
(8192, 65536) sum 0.501597 0.486351 0.497735 0.481579 +0.78% +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072) mean 0.971178 0.942988 0.957164 0.938316 +1.46% +0.50%
(8192, 131072) sum 0.971189 0.943232 0.956814 0.937816 +1.50% +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144) mean 1.953728 1.877648 1.904937 1.861692 +2.56% +0.86%
(8192, 262144) sum 1.953969 1.877538 1.905990 1.862547 +2.52% +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144) mean 0.970408 0.940965 0.957871 0.936732 +1.31% +0.45%
(4096, 262144) sum 0.970919 0.941652 0.957765 0.936676 +1.37% +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144) mean 0.501477 0.486976 0.497964 0.483570 +0.71% +0.70%
(2048, 262144) sum 0.501955 0.487213 0.498210 0.483218 +0.75% +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144) mean 0.266536 0.257111 0.265642 0.255439 +0.34% +0.65%
(1024, 262144) sum 0.266613 0.257096 0.265427 0.255472 +0.45% +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072) mean 0.087805 0.091200 0.085818 0.087851 +2.32% +3.81%
(512, 131072) sum 0.087788 0.091249 0.085373 0.087944 +2.83% +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000) mean 0.014503 0.012328 0.013663 0.010190 +6.15% +20.98%
(1000, 1000) sum 0.014545 0.012378 0.013662 0.010579 +6.46% +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129) mean 0.014163 0.008371 0.012893 0.008828 +9.85% -5.18%
(1024, 129) sum 0.014132 0.008751 0.013234 0.008868 +6.79% -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257) mean 0.014296 0.009101 0.013334 0.008563 +7.21% +6.28%
(1024, 257) sum 0.014302 0.009058 0.013020 0.008672 +9.85% +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587) mean 0.014127 0.010997 0.013443 0.009944 +5.09% +10.59%
(1024, 587) sum 0.014471 0.011373 0.013123 0.010354 +10.27% +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977) mean 0.015607 0.013566 0.015089 0.012152 +3.43% +11.64%
(2048, 977) sum 0.015953 0.013580 0.015039 0.011861 +6.08% +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128) mean 0.013982 0.008058 0.012747 0.008139 +9.69% -1.00%
(1024, 128) sum 0.013967 0.008071 0.012726 0.007859 +9.75% +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128) mean 0.014378 0.009627 0.013712 0.009395 +4.86% +2.47%
(8192, 128) sum 0.014389 0.009965 0.013718 0.009521 +4.89% +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130) mean 0.014156 0.008267 0.012895 0.008833 +9.78% -6.41%
(1024, 130) sum 0.013797 0.008277 0.012903 0.008512 +6.93% -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130) mean 0.014977 0.010026 0.013911 0.009876 +7.66% +1.52%
(8192, 130) sum 0.014994 0.010043 0.014235 0.009604 +5.33% +4.57%
====================================================================================================================================================================================
```
**FP16**
```
Tensor Shape          Operation   Full reduce before (ms)   Contiguous dim before (ms)   Full reduce after (ms)   Contiguous dim after (ms)   Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256) mean 0.022804 0.008298 0.015888 0.007848 +43.53% +5.73%
(256, 256) sum 0.023215 0.008328 0.015677 0.007850 +48.08% +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512) mean 0.013777 0.009988 0.012884 0.008512 +6.93% +17.34%
(512, 512) sum 0.013775 0.009622 0.012870 0.009028 +7.03% +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024) mean 0.014740 0.012322 0.013708 0.010239 +7.53% +20.34%
(1024, 1024) sum 0.014762 0.012756 0.013722 0.010307 +7.58% +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048) mean 0.018700 0.018364 0.018135 0.015078 +3.12% +21.79%
(2048, 2048) sum 0.018276 0.018415 0.018471 0.015127 -1.06% +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096) mean 0.034518 0.037000 0.033838 0.030617 +2.01% +20.85%
(4096, 4096) sum 0.034569 0.037448 0.033842 0.031100 +2.15% +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192) mean 0.087675 0.095176 0.085328 0.084105 +2.75% +13.16%
(8192, 8192) sum 0.088102 0.095211 0.085707 0.084090 +2.79% +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384) mean 0.147800 0.149263 0.146388 0.138390 +0.96% +7.86%
(8192, 16384) sum 0.148147 0.148957 0.146439 0.138801 +1.17% +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768) mean 0.266316 0.260294 0.265829 0.253411 +0.18% +2.72%
(8192, 32768) sum 0.266562 0.260717 0.265744 0.253308 +0.31% +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536) mean 0.502035 0.486077 0.498139 0.481374 +0.78% +0.98%
(8192, 65536) sum 0.501571 0.485733 0.498353 0.481350 +0.65% +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072) mean 0.971343 0.943016 0.956600 0.938622 +1.54% +0.47%
(8192, 131072) sum 0.971463 0.942991 0.957352 0.938334 +1.47% +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144) mean 1.952722 1.877165 1.906406 1.861455 +2.43% +0.84%
(8192, 262144) sum 1.952634 1.876388 1.904677 1.861282 +2.52% +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144) mean 0.970697 0.941298 0.956964 0.936160 +1.44% +0.55%
(4096, 262144) sum 0.969981 0.941078 0.957016 0.936260 +1.35% +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144) mean 0.501577 0.487208 0.498422 0.483493 +0.63% +0.77%
(2048, 262144) sum 0.502029 0.487124 0.497854 0.483643 +0.84% +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144) mean 0.266416 0.257383 0.265928 0.255140 +0.18% +0.88%
(1024, 262144) sum 0.266434 0.257081 0.265817 0.255143 +0.23% +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072) mean 0.087858 0.091296 0.085816 0.087745 +2.38% +4.05%
(512, 131072) sum 0.088144 0.091314 0.085664 0.087864 +2.90% +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000) mean 0.014977 0.012393 0.014141 0.010614 +5.91% +16.76%
(1000, 1000) sum 0.014589 0.012804 0.014118 0.010320 +3.34% +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129) mean 0.014208 0.008383 0.013273 0.008440 +7.04% -0.68%
(1024, 129) sum 0.013804 0.008863 0.013265 0.009003 +4.06% -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257) mean 0.014378 0.009109 0.013037 0.009038 +10.29% +0.79%
(1024, 257) sum 0.014387 0.009113 0.013396 0.008698 +7.40% +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587) mean 0.014207 0.011037 0.013182 0.010391 +7.78% +6.22%
(1024, 587) sum 0.014588 0.011453 0.013539 0.010049 +7.75% +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977) mean 0.016024 0.013614 0.015448 0.011845 +3.73% +14.93%
(2048, 977) sum 0.015990 0.014033 0.015406 0.012278 +3.79% +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128) mean 0.014037 0.007804 0.013143 0.008242 +6.80% -5.31%
(1024, 128) sum 0.014041 0.007847 0.012759 0.007850 +10.05% -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128) mean 0.014361 0.009644 0.014075 0.009061 +2.03% +6.43%
(8192, 128) sum 0.014366 0.010032 0.013702 0.009181 +4.85% +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130) mean 0.014226 0.008696 0.012894 0.008835 +10.33% -1.57%
(1024, 130) sum 0.013830 0.008740 0.013288 0.008989 +4.08% -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130) mean 0.015036 0.010019 0.013917 0.009538 +8.04% +5.04%
(8192, 130) sum 0.014652 0.010403 0.013900 0.009565 +5.41% +8.76%
====================================================================================================================================================================================
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165055
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790
2025-10-17 13:39:36 +00:00
e0fe37fa68
[MPS] Move torch.cat impl to Metal (#165373)
...
After this change, all of the cases tested in [this performance measurement script](10de64c5ac/cat/perf0.py) take either roughly the same runtime or less.
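A minimal sketch of how such MPS cat timings can be gathered (an illustration, not the linked perf0.py; it assumes an MPS-capable build, and the shapes match case 12 in the tables below):
```python
# Hedged sketch: wall-clock timing of torch.cat on the MPS backend.
import time
import torch

assert torch.backends.mps.is_available()

def bench_ms(fn, iters=100):
    fn()                        # warm-up
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.mps.synchronize()     # flush queued Metal work before stopping the clock
    return (time.perf_counter() - t0) / iters * 1000.0

a = torch.randn(1_000_000, 3, 2, device="mps")
b = torch.randn(1_000_000, 3, 2, device="mps")
print(f"cat dim=0 on mps: {bench_ms(lambda: torch.cat([a, b], dim=0)):.6f} ms")
```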
Before:
```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000857 ms, 0.016098 ms, 0.05, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000858 ms, 0.014861 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000806 ms, 0.015145 ms, 0.05, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000829 ms, 0.015355 ms, 0.05, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000582 ms, 1.02, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001076 ms, 0.022387 ms, 0.05, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000708 ms, 0.022300 ms, 0.03, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000640 ms, 0.014367 ms, 0.04, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000777 ms, 0.027506 ms, 0.03, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003383 ms, 0.269277 ms, 0.01, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.526138 ms, 0.650852 ms, 0.81, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.444091 ms, 0.628630 ms, 0.71, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.011870 ms, 0.989525 ms, 2.03, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.100653 ms, 0.948178 ms, 3.27, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.112174 ms, 0.954174 ms, 3.26, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
After:
```
idx: cpu time, mps time, speedup, op, args, kwargs
-----------------------------------------
0: 0.000790 ms, 0.013111 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': -1}
1: 0.000800 ms, 0.014419 ms, 0.06, cat, [[tensor(shape[5, 5]), tensor(shape[5, 5])]], {'dim': 1}
2: 0.000748 ms, 0.010019 ms, 0.07, cat, [[tensor(shape[10, 5]), tensor(shape[5, 5])]], {'dim': 0}
3: 0.000767 ms, 0.010063 ms, 0.08, cat, [[tensor(shape[1, 2, 3]), tensor(shape[1, 2, 3])]], {'dim': -2}
4: 0.000591 ms, 0.000591 ms, 1.00, cat, [[tensor(shape[0]), tensor(shape[0])]], {'dim': 0}
5: 0.001220 ms, 0.009763 ms, 0.12, cat, [[tensor(shape[0]), tensor(shape[5, 5])]], {'dim': 1}
6: 0.000739 ms, 0.006203 ms, 0.12, cat, [[tensor(shape[0, 5]), tensor(shape[5, 5])]], {'dim': 0}
7: 0.000647 ms, 0.009905 ms, 0.07, cat, [[tensor(shape[1]), tensor(shape[1])]], {}
8: 0.000753 ms, 0.007818 ms, 0.10, cat, [[tensor(shape[2, 2, 2, 2])], 1], {}
9: 0.003823 ms, 0.192723 ms, 0.02, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
10: 0.576564 ms, 0.733920 ms, 0.79, cat, "[[tensor(shape[3, 1, 2]), tensor(shape[3, 2, 2]), tensor(shape[3, 3, 2]), tensor(shape[3, 1, 2]), te...", {'dim': 1}
11: 0.462957 ms, 0.692799 ms, 0.67, cat, "[[tensor(shape[1, 3, 2]), tensor(shape[2, 3, 2]), tensor(shape[3, 3, 2]), tensor(shape[1, 3, 2]), te...", {'dim': 0}
12: 2.017181 ms, 0.968345 ms, 2.08, cat, [[tensor(shape[1000000, 3, 2]), tensor(shape[1000000, 3, 2])]], {'dim': 0}
13: 3.203508 ms, 0.986382 ms, 3.25, cat, [[tensor(shape[3, 1000000, 2]), tensor(shape[3, 1000000, 2])]], {'dim': 1}
14: 3.181249 ms, 1.007773 ms, 3.16, cat, [[tensor(shape[3, 2, 1000000]), tensor(shape[3, 2, 1000000])]], {'dim': 2}
```
Fixes #165350
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165373
Approved by: https://github.com/kulinseth, https://github.com/malfet
2025-10-17 00:03:04 +00:00