mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[ATen] Fix CUDA reduction warp shuffle order (#164790)
Typical warp shuffle reduction has the following pattern: <img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" /> which is exhibited in Triton generated by torch.compile: <img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" /> Switch the warp shuffle order to make bitwise equivalence between the 2 easier. PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/ Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order: ``` Tensor Shape Operation New all dims (ms) New dim=0 (ms) New dim=1 (ms) Old all dims (ms) Old dim=0 (ms) Old dim=1 (ms) ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.015817 0.016259 0.013642 0.015990 0.016258 0.013631 (1024, 1024) sum 0.015917 0.015906 0.013359 0.015707 0.016266 0.013226 (1024, 1024) min 0.016021 0.024625 0.015631 0.015761 0.024485 0.015317 (1024, 1024) max 0.016349 0.024971 0.015972 0.015771 0.025001 0.015314 (1024, 1024) argmin 0.018070 0.024448 0.015578 0.018135 0.025370 0.015322 (1024, 1024) argmax 0.018427 0.024859 0.015932 0.018164 0.024452 0.015639 (1024, 1024) var 0.020078 0.026413 0.020295 0.020199 0.026381 0.020214 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.023826 0.023726 0.022273 0.023236 0.023776 0.022248 (2048, 2048) sum 0.023840 0.023355 0.021974 0.023294 0.023354 0.021884 (2048, 2048) min 0.024519 0.041263 0.024620 0.023292 0.041491 0.024358 (2048, 2048) max 0.024509 0.041670 0.024277 0.023334 0.041231 0.024395 (2048, 2048) argmin 0.026125 0.041282 0.024567 0.026772 0.041773 0.024296 (2048, 2048) argmax 0.026117 0.041487 0.024572 0.026412 0.041477 0.024273 (2048, 2048) var 0.026603 0.048581 0.031308 0.027587 0.048603 0.030860 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.053927 0.057070 0.054073 0.053028 0.057544 0.053935 (4096, 4096) sum 0.053604 0.057410 0.054451 0.053076 0.057033 0.054266 (4096, 4096) min 0.054293 0.109122 0.058363 0.053821 0.108689 0.058382 (4096, 4096) max 0.054258 0.108035 0.058703 0.053492 0.110552 0.058376 (4096, 4096) argmin 0.056805 0.111167 0.058301 0.056836 0.112325 0.058292 (4096, 4096) argmax 0.056488 0.110958 0.058636 0.056844 0.111000 0.057928 (4096, 4096) var 0.058936 0.141755 0.068693 0.059735 0.141284 0.068500 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.145552 0.148082 0.138647 0.145364 0.147818 0.138207 (8192, 8192) sum 0.145985 0.147900 0.138714 0.145755 0.148031 0.138616 (8192, 8192) min 0.146566 0.205359 0.192739 0.145611 0.205237 0.182335 (8192, 8192) max 0.146526 0.204844 0.193050 0.146073 0.205457 0.182697 (8192, 8192) argmin 0.150190 0.206605 0.192543 0.150654 0.206847 0.182007 (8192, 8192) argmax 0.150481 0.206368 0.192535 0.150845 0.206430 0.182022 (8192, 8192) var 0.150884 0.184546 0.203900 0.151594 0.184172 0.197983 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1, 1024, 128) mean 0.014293 0.008119 0.014533 0.013861 0.008022 0.014449 (1, 1024, 128) sum 0.014039 0.007877 0.014111 0.014219 0.008227 0.014045 (1, 1024, 128) min 0.014159 0.011354 0.023493 0.014271 0.010862 0.023644 (1, 1024, 128) max 0.014154 0.011027 0.023368 0.014259 0.011234 0.023692 (1, 1024, 128) argmin 0.016403 0.005677 0.023328 0.016273 0.005683 0.024073 (1, 1024, 128) argmax 0.016734 0.005675 0.023437 0.016580 0.005318 0.023331 (1, 1024, 128) var 0.018338 0.009549 0.025538 0.018528 0.009391 0.024777 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (5, 1024, 128) mean 0.014873 0.010131 0.015546 0.015123 0.010131 0.015481 (5, 1024, 128) sum 0.015334 0.009673 0.015824 0.014736 0.009671 0.015438 (5, 1024, 128) min 0.015047 0.013252 0.024573 0.014803 0.013163 0.024551 (5, 1024, 128) max 0.015050 0.013339 0.024197 0.014810 0.013525 0.024230 (5, 1024, 128) argmin 0.017341 0.012737 0.024306 0.017471 0.012379 0.024991 (5, 1024, 128) argmax 0.017345 0.012411 0.024421 0.017422 0.012471 0.024237 (5, 1024, 128) var 0.019973 0.011453 0.026188 0.020050 0.011438 0.026282 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10, 1024, 128) mean 0.016976 0.011575 0.016831 0.016722 0.011927 0.017173 (10, 1024, 128) sum 0.017039 0.011841 0.017159 0.016385 0.011860 0.016753 (10, 1024, 128) min 0.017036 0.015331 0.026770 0.016944 0.015205 0.027166 (10, 1024, 128) max 0.017369 0.015348 0.027077 0.016531 0.015716 0.026819 (10, 1024, 128) argmin 0.019203 0.014447 0.026813 0.018994 0.014497 0.027313 (10, 1024, 128) argmax 0.019563 0.014795 0.027140 0.019460 0.014912 0.026733 (10, 1024, 128) var 0.020529 0.014316 0.030405 0.020719 0.013960 0.029964 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100, 1024, 128) mean 0.045046 0.039168 0.046082 0.044839 0.039217 0.045782 (100, 1024, 128) sum 0.045094 0.039150 0.045777 0.044496 0.039542 0.046083 (100, 1024, 128) min 0.045768 0.054466 0.076244 0.044915 0.053943 0.076599 (100, 1024, 128) max 0.045748 0.054459 0.076188 0.044931 0.053949 0.076856 (100, 1024, 128) argmin 0.048275 0.054046 0.076647 0.048694 0.054105 0.077004 (100, 1024, 128) argmax 0.048267 0.054395 0.077401 0.048691 0.054131 0.076751 (100, 1024, 128) var 0.049710 0.043254 0.083077 0.050971 0.043251 0.082378 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000, 100) mean 0.202312 0.196723 0.197765 0.201774 0.196641 0.197459 (1000, 1000, 100) sum 0.202651 0.196682 0.197736 0.202175 0.196313 0.197523 (1000, 1000, 100) min 0.203022 0.264762 0.269200 0.202729 0.264129 0.268694 (1000, 1000, 100) max 0.202864 0.264396 0.269388 0.202486 0.263896 0.268720 (1000, 1000, 100) argmin 0.226727 0.263781 0.268651 0.226597 0.264676 0.268983 (1000, 1000, 100) argmax 0.226412 0.264469 0.269090 0.226570 0.264595 0.269178 (1000, 1000, 100) var 0.243223 0.204079 0.216096 0.241942 0.204079 0.215925 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (10000, 100) mean 0.016193 0.020277 0.014316 0.016152 0.020324 0.013712 (10000, 100) sum 0.016289 0.020237 0.014034 0.016168 0.020265 0.013708 (10000, 100) min 0.016046 0.030872 0.019609 0.016208 0.030867 0.018627 (10000, 100) max 0.016369 0.030835 0.019257 0.016218 0.030861 0.018209 (10000, 100) argmin 0.017957 0.031171 0.019517 0.018050 0.031556 0.018077 (10000, 100) argmax 0.017961 0.031658 0.019521 0.018060 0.031564 0.018087 (10000, 100) var 0.020393 0.035652 0.019339 0.020144 0.035987 0.019171 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (100000, 10) mean 0.015718 0.016576 0.016555 0.015999 0.016246 0.014869 (100000, 10) sum 0.015833 0.016247 0.016572 0.016007 0.016627 0.014872 (100000, 10) min 0.015888 0.020510 0.023920 0.015671 0.020821 0.021417 (100000, 10) max 0.015889 0.020479 0.023918 0.016077 0.020386 0.021421 (100000, 10) argmin 0.018233 0.020863 0.023647 0.017574 0.020864 0.021103 (100000, 10) argmax 0.017896 0.020527 0.023296 0.017569 0.020447 0.021098 (100000, 10) var 0.020005 0.024198 0.024372 0.020075 0.024167 0.022415 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 1023) mean 1.874816 1.963506 1.903909 1.873279 1.963859 1.903230 (1023, 1023, 1023) sum 1.875030 1.965716 1.902458 1.873566 1.960730 1.901642 (1023, 1023, 1023) min 1.878563 2.473455 2.179092 1.875174 2.482086 2.183027 (1023, 1023, 1023) max 1.879128 2.474803 2.178895 1.874831 2.482253 2.183884 (1023, 1023, 1023) argmin 1.921800 2.476629 2.174831 1.923987 2.472641 2.170453 (1023, 1023, 1023) argmax 1.922605 2.476688 2.177927 1.923366 2.472808 2.172979 (1023, 1023, 1023) var 1.972606 3.088695 2.758797 1.978679 3.095658 2.762243 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 255) mean 0.489984 0.500954 0.492957 0.489891 0.500654 0.491971 (1023, 1023, 255) sum 0.490228 0.500764 0.492289 0.489624 0.501089 0.492824 (1023, 1023, 255) min 0.491457 0.563560 0.553334 0.490355 0.564709 0.554754 (1023, 1023, 255) max 0.491396 0.563628 0.553345 0.490017 0.565004 0.554947 (1023, 1023, 255) argmin 0.503666 0.561512 0.551831 0.503845 0.560972 0.551017 (1023, 1023, 255) argmax 0.503602 0.561185 0.551407 0.504328 0.561267 0.551448 (1023, 1023, 255) var 0.510844 0.709452 0.701630 0.512693 0.710365 0.701965 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1023, 1023, 377) mean 0.707439 0.727646 0.712019 0.706769 0.727101 0.711632 (1023, 1023, 377) sum 0.707780 0.727453 0.711554 0.706807 0.726656 0.711729 (1023, 1023, 377) min 0.709423 0.819809 0.794379 0.707847 0.822086 0.796664 (1023, 1023, 377) max 0.709297 0.819780 0.794308 0.707566 0.821913 0.796690 (1023, 1023, 377) argmin 0.725028 0.817088 0.791695 0.726039 0.816445 0.790828 (1023, 1023, 377) argmax 0.725301 0.817011 0.791420 0.726040 0.816917 0.791143 (1023, 1023, 377) var 0.740859 1.034165 1.006712 0.743413 1.035506 1.007638 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/164790 Approved by: https://github.com/ngimel, https://github.com/eqy ghstack dependencies: #165494
This commit is contained in:
committed by
PyTorch MergeBot
parent
7e6721fb0a
commit
36371b8ec7
@ -220,6 +220,8 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
(torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
|
||||
(torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
|
||||
(torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
|
||||
(torch.float16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
|
||||
(torch.bfloat16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
|
||||
# see https://github.com/pytorch/pytorch/pull/96264
|
||||
(torch.float16, torch.ops.aten.mv.default): 1e-5,
|
||||
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
|
||||
@ -295,6 +297,7 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
|
||||
rtol, atol = tol_table[(decomp.dtype, op)]
|
||||
else:
|
||||
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
|
||||
|
||||
test_case.assertEqual(
|
||||
orig,
|
||||
decomp,
|
||||
|
Reference in New Issue
Block a user