Mirror of https://github.com/vllm-project/vllm.git (synced 2025-11-05 18:25:13 +08:00)

Compare commits: v0.8.0rc2...pd_schedul (683 commits)
(The compare page lists the 683 commit SHAs in this range, from 161010c384 through 5340b0e221; the author, date, and message columns were not captured in this extraction.)
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 -b auto -l 1319 -f 5 -t 1
+model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.31
+  - name: "exact_match,flexible-extract"
+    value: 0.47
+limit: 1319
+num_fewshot: 5
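Note: the new config's own first line records how this baseline is exercised; reproduced here for convenience, with flag meanings inferred from the fields above (-l appears to map to limit, -f to num_fewshot, -t to the tensor-parallel size):

    # Run the GSM8K lm-eval baseline for the quantized Qwen1.5-MoE checkpoint
    bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh \
        -m nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16 \
        -b auto -l 1319 -f 5 -t 1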
@@ -4,7 +4,7 @@ Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
-Minitron-4B-Base-FP8.yaml
+Qwen1.5-MoE-W4A16-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-FP8W8.yaml
 Meta-Llama-3-8B-QQQ.yaml
@@ -10,15 +10,24 @@ set -x
 set -o pipefail
 
 check_gpus() {
-  # check the number of GPUs and GPU type.
-  declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l)
+  if command -v nvidia-smi; then
+    # check the number of GPUs and GPU type.
+    declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l)
+  elif command -v amd-smi; then
+    declare -g gpu_count=$(amd-smi list | grep 'GPU' | wc -l)
+  fi
+
   if [[ $gpu_count -gt 0 ]]; then
     echo "GPU found."
   else
     echo "Need at least 1 GPU to run benchmarking."
     exit 1
   fi
-  declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
+  if command -v nvidia-smi; then
+    declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
+  elif command -v amd-smi; then
+    declare -g gpu_type=$(amd-smi static -g 0 -a | grep 'MARKET_NAME' | awk '{print $2}')
+  fi
   echo "GPU type is $gpu_type"
 }
 
@@ -90,9 +99,15 @@ kill_gpu_processes() {
 
 
   # wait until GPU memory usage smaller than 1GB
-  while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
-    sleep 1
-  done
+  if command -v nvidia-smi; then
+    while [ "$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1)" -ge 1000 ]; do
+      sleep 1
+    done
+  elif command -v amd-smi; then
+    while [ "$(amd-smi metric -g 0 | grep 'USED_VRAM' | awk '{print $2}')" -ge 1000 ]; do
+      sleep 1
+    done
+  fi
 
   # remove vllm config file
   rm -rf ~/.config/vllm
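Taken together, these hunks make the benchmark script dispatch on whichever vendor tool is installed on the runner. A condensed sketch of the resulting pattern, assembled from the new side of the diff above (not the full script):

    check_gpus() {
      # Prefer nvidia-smi; fall back to amd-smi on ROCm runners.
      if command -v nvidia-smi; then
        declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l)
        declare -g gpu_type=$(nvidia-smi --query-gpu=name --format=csv,noheader | awk '{print $2}')
      elif command -v amd-smi; then
        declare -g gpu_count=$(amd-smi list | grep 'GPU' | wc -l)
        declare -g gpu_type=$(amd-smi static -g 0 -a | grep 'MARKET_NAME' | awk '{print $2}')
      fi
      echo "GPU type is $gpu_type"
    }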
@@ -361,7 +376,7 @@ main() {
   # get the current IP address, required by benchmark_serving.py
   export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
   # turn of the reporting of the status of each request, to clean up the terminal output
-  export VLLM_LOG_LEVEL="WARNING"
+  export VLLM_LOGGING_LEVEL="WARNING"
 
   # prepare for benchmarking
   cd benchmarks || exit 1
@@ -63,10 +63,12 @@
       "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
       "disable_log_requests": "",
       "tensor_parallel_size": 4,
       "swap_space": 16,
-      "speculative_model": "turboderp/Qwama-0.5B-Instruct",
-      "num_speculative_tokens": 4,
-      "speculative_draft_tensor_parallel_size": 1
+      "speculative_config": {
+        "model": "turboderp/Qwama-0.5B-Instruct",
+        "num_speculative_tokens": 4,
+        "draft_tensor_parallel_size": 1
+      }
     },
     "client_parameters": {
       "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
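This hunk tracks the consolidation of the flat speculative-decoding parameters into a single nested speculative_config object. Assuming the serve CLI of this vLLM version accepts the same structure as a JSON string (an assumption based on the shape above, not something this diff shows), an equivalent launch might look like:

    # Hypothetical CLI equivalent of the nested config; the --speculative-config
    # flag name and JSON form are assumptions, verify against your vLLM version.
    vllm serve meta-llama/Meta-Llama-3.1-70B-Instruct \
        --tensor-parallel-size 4 \
        --swap-space 16 \
        --speculative-config '{"model": "turboderp/Qwama-0.5B-Instruct", "num_speculative_tokens": 4, "draft_tensor_parallel_size": 1}'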
@@ -3,10 +3,10 @@ steps:
   agents:
     queue: cpu_queue_postmerge
   commands:
-    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain ."
+    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
     - "mkdir artifacts"
     - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
-    - "bash .buildkite/upload-wheels.sh"
+    - "bash .buildkite/scripts/upload-wheels.sh"
   env:
     DOCKER_BUILDKIT: "1"
 
@@ -14,10 +14,10 @@ steps:
   agents:
     queue: cpu_queue_postmerge
   commands:
-    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ."
+    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
     - "mkdir artifacts"
     - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
-    - "bash .buildkite/upload-wheels.sh"
+    - "bash .buildkite/scripts/upload-wheels.sh"
   env:
     DOCKER_BUILDKIT: "1"
 
@@ -31,10 +31,10 @@ steps:
   agents:
     queue: cpu_queue_postmerge
   commands:
-    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ."
+    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
     - "mkdir artifacts"
     - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
-    - "bash .buildkite/upload-wheels.sh"
+    - "bash .buildkite/scripts/upload-wheels.sh"
   env:
     DOCKER_BUILDKIT: "1"
 
@@ -48,7 +48,7 @@ steps:
     queue: cpu_queue_postmerge
   commands:
     - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
-    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ."
+    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
     - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
 
 - label: "Build and publish TPU release image"
@@ -57,7 +57,7 @@ steps:
   agents:
     queue: tpu_queue_postmerge
   commands:
-    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f Dockerfile.tpu ."
+    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ."
     - "docker push vllm/vllm-tpu:nightly"
     - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT"
   plugins:
@@ -82,7 +82,7 @@ steps:
     queue: cpu_queue_postmerge
   commands:
     - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
-    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --progress plain -f Dockerfile.cpu ."
+    - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
     - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
   env:
     DOCKER_BUILDKIT: "1"
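The common thread in the pipeline hunks above is that the Dockerfiles now live under docker/ and are passed explicitly with -f. A local sketch of the equivalent build (the image tag here is chosen only for illustration):

    # Build the OpenAI-compatible server image from the relocated Dockerfile
    DOCKER_BUILDKIT=1 docker build \
        -f docker/Dockerfile \
        --target vllm-openai \
        -t vllm-local:dev .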
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-# This script build the OpenVINO docker image and run the offline inference inside the container.
-# It serves a sanity check for compilation and basic model usage.
-set -ex
-
-# Try building the docker image
-docker build -t openvino-test -f Dockerfile.openvino .
-
-# Setup cleanup
-remove_docker_container() { docker rm -f openvino-test || true; }
-trap remove_docker_container EXIT
-remove_docker_container
-
-# Run the image and launch offline inference
-docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic/generate.py --model facebook/opt-125m
@@ -1,25 +0,0 @@
-#!/bin/bash
-
-set -e
-
-# Build the docker image.
-docker build -f Dockerfile.tpu -t vllm-tpu .
-
-# Set up cleanup.
-remove_docker_container() { docker rm -f tpu-test || true; }
-trap remove_docker_container EXIT
-# Remove the container that might not be cleaned up in the previous run.
-remove_docker_container
-
-# For HF_TOKEN.
-source /etc/environment
-# Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it \
-    -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
-    vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
-    && python3 -m pip install pytest \
-    && python3 -m pip install lm_eval[api]==0.4.4 \
-    && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
-    && python3 /workspace/vllm/tests/tpu/test_compilation.py \
-    && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
-    && python3 /workspace/vllm/examples/offline_inference/tpu.py"
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-set -e
-
-# Build the docker image.
-docker build -f Dockerfile.tpu -t vllm-tpu .
-
-# Set up cleanup.
-remove_docker_container() { docker rm -f tpu-test || true; }
-trap remove_docker_container EXIT
-# Remove the container that might not be cleaned up in the previous run.
-remove_docker_container
-
-# For HF_TOKEN.
-source /etc/environment
-# Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it \
-    -e "HF_TOKEN=$HF_TOKEN" -e "VLLM_USE_V1=1" --name tpu-test \
-    vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
-    && python3 -m pip install pytest \
-    && python3 -m pip install lm_eval[api]==0.4.4 \
-    && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
-    && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
-    && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
-    && python3 /workspace/vllm/tests/tpu/test_compilation.py \
-    && python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
-    && python3 /workspace/vllm/examples/offline_inference/tpu.py"
@@ -105,19 +105,33 @@ fi
 if [[ $commands == *" entrypoints/openai "* ]]; then
   commands=${commands//" entrypoints/openai "/" entrypoints/openai \
   --ignore=entrypoints/openai/test_audio.py \
-  --ignore=entrypoints/openai/test_chat.py \
   --ignore=entrypoints/openai/test_shutdown.py \
   --ignore=entrypoints/openai/test_completion.py \
   --ignore=entrypoints/openai/test_sleep.py \
   --ignore=entrypoints/openai/test_models.py \
+  --ignore=entrypoints/openai/test_lora_adapters.py \
+  --ignore=entrypoints/openai/test_return_tokens_as_ids.py \
+  --ignore=entrypoints/openai/test_root_path.py \
+  --ignore=entrypoints/openai/test_tokenization.py \
   --ignore=entrypoints/openai/test_prompt_validation.py "}
 fi
 
 #ignore certain Entrypoints/llm tests
-if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
-  commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
+if [[ $commands == *" entrypoints/llm "* ]]; then
+  commands=${commands//" entrypoints/llm "/" entrypoints/llm \
+  --ignore=entrypoints/llm/test_chat.py \
+  --ignore=entrypoints/llm/test_accuracy.py \
+  --ignore=entrypoints/llm/test_init.py \
+  --ignore=entrypoints/llm/test_generate_multiple_loras.py \
+  --ignore=entrypoints/llm/test_prompt_validation.py "}
 fi
 
+#Obsolete currently
+##ignore certain Entrypoints/llm tests
+#if [[ $commands == *" && pytest -v -s entrypoints/llm/test_guided_generate.py"* ]]; then
+#  commands=${commands//" && pytest -v -s entrypoints/llm/test_guided_generate.py"/" "}
+#fi
+
 # --ignore=entrypoints/openai/test_encoder_decoder.py \
 # --ignore=entrypoints/openai/test_embedding.py \
 # --ignore=entrypoints/openai/test_oot_registration.py
@@ -134,9 +148,10 @@ if [[ $commands == *"--shard-id="* ]]; then
     # assign shard-id for each shard
     commands_gpu=${commands//"--shard-id= "/"--shard-id=${GPU} "}
     echo "Shard ${GPU} commands:$commands_gpu"
+    echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
     docker run \
-      --device /dev/kfd --device /dev/dri \
-      --network host \
+      --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
+      --network=host \
       --shm-size=16gb \
       --rm \
       -e HIP_VISIBLE_DEVICES="${GPU}" \
@@ -163,9 +178,10 @@ if [[ $commands == *"--shard-id="* ]]; then
     fi
   done
 else
+  echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES"
   docker run \
-    --device /dev/kfd --device /dev/dri \
-    --network host \
+    --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
+    --network=host \
     --shm-size=16gb \
     --rm \
     -e HIP_VISIBLE_DEVICES=0 \
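Both ROCm hunks replace the hard-coded --device /dev/dri with an expansion of $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES supplied by the agent. As a hypothetical illustration of what that variable might carry (the render-node indices and image name below are assumptions, not taken from this diff), it would expand to per-GPU device flags:

    # Assumed agent-provided value; actual renderD* indices vary per host.
    export BUILDKITE_AGENT_META_DATA_RENDER_DEVICES="--device /dev/dri/renderD128 --device /dev/dri/renderD129"
    docker run --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \
        --network=host --shm-size=16gb --rm \
        -e HIP_VISIBLE_DEVICES=0 rocm/vllm-ci:example true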
.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh (new executable file, 38 lines)
@@ -0,0 +1,38 @@
+#!/bin/bash
+
+# This script build the CPU docker image and run the offline inference inside the container.
+# It serves a sanity check for compilation and basic model usage.
+set -ex
+
+# Setup cleanup
+remove_docker_container() { podman rm -f cpu-test-ubi9-ppc || true; podman system prune -f; }
+trap remove_docker_container EXIT
+remove_docker_container
+
+# Try building the docker image
+podman build -t cpu-test-ubi9-ppc -f docker/Dockerfile.ppc64le .
+
+# Run the image
+podman run -itd --entrypoint /bin/bash -v /tmp/:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --name cpu-test-ubi9-ppc cpu-test-ubi9-ppc
+
+function cpu_tests() {
+
+  # offline inference
+  podman exec cpu-test-ubi9-ppc bash -c "
+    set -e
+    python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
+
+  # Run basic model test
+  podman exec cpu-test-ubi9-ppc bash -c "
+    set -e
+    pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib
+    pip install sentence-transformers datamodel_code_generator
+    pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach]
+    pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]
+    pytest -v -s tests/models/encoder_decoder/language -m cpu_model"
+}
+
+# All of CPU tests are expected to be finished less than 40 mins.
+export -f cpu_tests
+timeout 40m bash -c cpu_tests
+
@@ -10,5 +10,4 @@ trap remove_docker_container EXIT
 remove_docker_container
 
 # Try building the docker image
-docker build -t cpu-test -f Dockerfile.ppc64le .
+docker build -t cpu-test -f docker/Dockerfile.s390x .
 
@@ -8,15 +8,19 @@ set -ex
 CORE_RANGE=${CORE_RANGE:-48-95}
 NUMA_NODE=${NUMA_NODE:-1}
 
-# Try building the docker image
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test-"$BUILDKITE_BUILD_NUMBER" -f Dockerfile.cpu .
-numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu .
-
 # Setup cleanup
-remove_docker_container() { set -e; docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; }
+remove_docker_container() {
+  set -e;
+  docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true;
+  docker image rm cpu-test-"$BUILDKITE_BUILD_NUMBER" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 || true;
+}
 trap remove_docker_container EXIT
 remove_docker_container
 
+# Try building the docker image
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$BUILDKITE_BUILD_NUMBER" --target vllm-test -f docker/Dockerfile.cpu .
+numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
+
 # Run the image, setting --shm-size=4g for tensor parallel.
 docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
   --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"
@@ -36,8 +40,8 @@ function cpu_tests() {
   # Run basic model test
   docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
     set -e
-    pip install -r vllm/requirements/test.txt
-    pip install -r vllm/requirements/cpu.txt
+    pytest -v -s tests/kernels/test_cache.py -m cpu_model
+    pytest -v -s tests/kernels/test_mla_decode_cpu.py -m cpu_model
     pytest -v -s tests/models/decoder_only/language -m cpu_model
     pytest -v -s tests/models/embedding/language -m cpu_model
     pytest -v -s tests/models/encoder_decoder/language -m cpu_model
@@ -9,11 +9,13 @@ python3 use_existing_torch.py
 
 # Try building the docker image
 DOCKER_BUILDKIT=1 docker build . \
+  --file docker/Dockerfile \
   --target vllm-openai \
   --platform "linux/arm64" \
   -t gh200-test \
   --build-arg max_jobs=66 \
   --build-arg nvcc_threads=2 \
+  --build-arg RUN_WHEEL_CHECK=false \
   --build-arg torch_cuda_arch_list="9.0+PTX" \
   --build-arg vllm_fa_cmake_gpu_arches="90-real"
 
@@ -23,6 +25,6 @@ trap remove_docker_container EXIT
 remove_docker_container
 
 # Run the image and test offline inference
-docker run -e HF_TOKEN -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
+docker run -e HF_TOKEN -e VLLM_WORKER_MULTIPROC_METHOD=spawn -v /root/.cache/huggingface:/root/.cache/huggingface --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
   python3 examples/offline_inference/basic/generate.py --model meta-llama/Llama-3.2-1B
 '
@@ -5,7 +5,7 @@
 set -ex
 
 # Try building the docker image
-docker build -t hpu-test-env -f Dockerfile.hpu .
+docker build -t hpu-test-env -f docker/Dockerfile.hpu .
 
 # Setup cleanup
 # certain versions of HPU software stack have a bug that can
@@ -35,7 +35,7 @@ else
   date "+%s" > /tmp/neuron-docker-build-timestamp
 fi
 
-docker build -t "${image_name}" -f Dockerfile.neuron .
+docker build -t "${image_name}" -f docker/Dockerfile.neuron .
 
 # Setup cleanup
 remove_docker_container() {
.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh (new executable file, 49 lines)
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+set -xue
+
+# Build the docker image.
+docker build -f docker/Dockerfile.tpu -t vllm-tpu .
+
+# Set up cleanup.
+remove_docker_container() { docker rm -f tpu-test || true; }
+trap remove_docker_container EXIT
+# Remove the container that might not be cleaned up in the previous run.
+remove_docker_container
+
+# For HF_TOKEN.
+source /etc/environment
+# Run a simple end-to-end example.
+docker run --privileged --net host --shm-size=16G -it \
+    -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
+    vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
+    && python3 -m pip install pytest tpu-info \
+    && python3 -m pip install lm_eval[api]==0.4.4 \
+    && export VLLM_USE_V1=1 \
+    && export VLLM_XLA_CHECK_RECOMPILATION=1 \
+    && echo HARDWARE \
+    && tpu-info \
+    && echo TEST_0 \
+    && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
+    && echo TEST_1 \
+    && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \
+    && echo TEST_2 \
+    && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \
+    && echo TEST_3 \
+    && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
+    && echo TEST_4 \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
+    && echo TEST_5 \
+    && python3 /workspace/vllm/examples/offline_inference/tpu.py \
+    && echo TEST_6 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
+    && echo TEST_7 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
+    && echo TEST_8 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
+    && echo TEST_9 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
+
+
+# TODO: This test fails because it uses RANDOM_SEED sampling
+# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
@@ -8,14 +8,15 @@ image_name="xpu/vllm-ci:${BUILDKITE_COMMIT}"
 container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
 
 # Try building the docker image
-docker build -t ${image_name} -f Dockerfile.xpu .
+docker build -t ${image_name} -f docker/Dockerfile.xpu .
 
 # Setup cleanup
 remove_docker_container() {
-  docker rm -f "${container_name}" || docker image rm -f "${image_name}" || true;
+  docker rm -f "${container_name}" || true;
+  docker image rm -f "${image_name}" || true;
+  docker system prune -f || true;
 }
 trap remove_docker_container EXIT
-remove_docker_container
 
 # Run the image and test offline inference/tensor parallel
 docker run \
@@ -25,6 +26,6 @@ docker run \
   --name "${container_name}" \
   "${image_name}" \
   sh -c '
-  python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
-  python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
+  VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
+  VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
 '
@@ -5,8 +5,8 @@
 set -ex
 set -o pipefail
 
-# cd into parent directory of this file
-cd "$(dirname "${BASH_SOURCE[0]}")/.."
+# cd 2 levels into the working directory
+cd "$(dirname "${BASH_SOURCE[0]}")/../.."
 
 (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
 
@@ -3,7 +3,7 @@
 set -euox pipefail
 
 if [[ $# -lt 4 ]]; then
-  echo "Usage: .buildkite/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN"
+  echo "Usage: .buildkite/scripts/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN"
   exit 1
 fi
 
@@ -104,7 +104,7 @@ steps:
 - label: Entrypoints Test # 40min
   working_dir: "/vllm-workspace/tests"
   fast_check: true
-  mirror_hardwares: [amd]
+  #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/entrypoints/llm
@@ -118,7 +118,7 @@ steps:
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_openai_schema.py
   - pytest -v -s entrypoints/test_chat_utils.py
   - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
 
@@ -135,8 +135,14 @@ steps:
   - examples/offline_inference/rlhf.py
   - examples/offline_inference/rlhf_colocate.py
   - tests/examples/offline_inference/data_parallel.py
+  - tests/v1/test_async_llm_dp.py
   commands:
+  # test with tp=2 and external_dp=2
+  - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
+  - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
+  # test with internal dp
   - python3 ../examples/offline_inference/data_parallel.py
+  - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
   - pytest -v -s distributed/test_utils.py
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
@@ -149,6 +155,7 @@ steps:
   - popd
 
 - label: Metrics, Tracing Test # 10min
+  mirror_hardwares: [amd]
   num_gpus: 2
   source_file_dependencies:
   - vllm/
@@ -156,18 +163,13 @@ steps:
   - tests/tracing
   commands:
   - pytest -v -s metrics
-  - "pip install \
-    'opentelemetry-sdk>=1.26.0,<1.27.0' \
-    'opentelemetry-api>=1.26.0,<1.27.0' \
-    'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
-    'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'"
   - pytest -v -s tracing
 
 ##### fast check tests #####
 ##### 1 GPU test #####
 
 - label: Regression Test # 5min
-  mirror_hardwares: [amd]
+  #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/
   - tests/test_regression
@@ -198,7 +200,6 @@ steps:
   commands:
   # split the test to avoid interference
   - pytest -v -s v1/core
-  - pytest -v -s v1/entrypoints
   - pytest -v -s v1/engine
   - pytest -v -s v1/entrypoints
   - pytest -v -s v1/sample
@@ -279,13 +280,21 @@ steps:
   - pytest -v -s spec_decode/e2e/test_eagle_correctness.py
 
 - label: LoRA Test %N # 15min each
-  mirror_hardwares: [amd]
+  #mirror_hardwares: [amd]
   source_file_dependencies:
   - vllm/lora
   - tests/lora
-  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
+  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
   parallelism: 4
 
+- label: PyTorch Compilation Unit Tests
+  source_file_dependencies:
+  - vllm/
+  - tests/compile
+  commands:
+  - pytest -v -s compile/test_pass_manager.py
+  - pytest -v -s compile/test_fusion.py
+
 - label: PyTorch Fullgraph Smoke Test # 9min
   source_file_dependencies:
   - vllm/
@@ -304,7 +313,7 @@ steps:
   - pytest -v -s compile/test_full_graph.py
 
 - label: Kernels Test %N # 1h each
-  mirror_hardwares: [amd]
+  # mirror_hardwares: [amd]
   source_file_dependencies:
   - csrc/
   - vllm/attention
@@ -314,7 +323,7 @@ steps:
   parallelism: 4
 
 - label: Tensorizer Test # 11min
-  mirror_hardwares: [amd]
+  # mirror_hardwares: [amd]
   soft_fail: true
   source_file_dependencies:
   - vllm/model_executor/model_loader
@@ -330,7 +339,14 @@ steps:
   source_file_dependencies:
   - benchmarks/
   commands:
-  - bash run-benchmarks.sh
+  - bash scripts/run-benchmarks.sh
 
+- label: Benchmarks CLI Test # 10min
+  source_file_dependencies:
+  - vllm/
+  - tests/benchmarks/
+  commands:
+  - pytest -v -s benchmarks/
+
 - label: Quantization Test # 33min
   source_file_dependencies:
@@ -365,12 +381,14 @@ steps:
 
 - label: OpenAI-Compatible Tool Use # 20 min
   fast_check: false
-  mirror_hardwares: [ amd ]
+  #mirror_hardwares: [ amd ]
   source_file_dependencies:
   - vllm/
   - tests/tool_use
+  - tests/mistral_tool_use
   commands:
   - pytest -v -s tool_use
+  - pytest -v -s mistral_tool_use
 
 ##### models test #####
 
@@ -382,7 +400,9 @@ steps:
   - pytest -v -s models/test_transformers.py
   - pytest -v -s models/test_registry.py
   # V1 Test: https://github.com/vllm-project/vllm/issues/14531
-  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
+  - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
 
 - label: Language Models Test (Standard) # 32min
   #mirror_hardwares: [amd]
@@ -392,6 +412,8 @@ steps:
   - tests/models/embedding/language
   - tests/models/encoder_decoder/language
   commands:
+  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
+  - pip install causal-conv1d
   - pytest -v -s models/decoder_only/language -m 'core_model or quant_model'
   - pytest -v -s models/embedding/language -m core_model
 
@@ -403,6 +425,8 @@ steps:
   - tests/models/embedding/language
   - tests/models/encoder_decoder/language
   commands:
+  # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
+  - pip install causal-conv1d
   - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
   - pytest -v -s models/embedding/language -m 'not core_model'
 
@@ -419,11 +443,12 @@ steps:
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/multimodal
   - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
-  - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
+  - pytest -v -s models/decoder_only/vision_language -m 'core_model or quant_model'
   - pytest -v -s models/embedding/vision_language -m core_model
   - pytest -v -s models/encoder_decoder/audio_language -m core_model
   - pytest -v -s models/encoder_decoder/language -m core_model
   - pytest -v -s models/encoder_decoder/vision_language -m core_model
+  - pytest -v -s models/decoder_only/vision_language/test_interleaved.py
 
 - label: Multi-Modal Models Test (Extended) 1 # 48m
   optional: true
@@ -437,10 +462,7 @@ steps:
   - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
   - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model'
   - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model'
-  # HACK - run phi3v tests separately to sidestep this transformers bug
-  # https://github.com/huggingface/transformers/issues/34307
-  - pytest -v -s models/decoder_only/vision_language/test_phi3v.py
-  - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
+  - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py models/decoder_only/vision_language -m 'not core_model and not quant_model'
   - pytest -v -s models/embedding/vision_language -m 'not core_model'
   - pytest -v -s models/encoder_decoder/language -m 'not core_model'
   - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model'
@@ -456,6 +478,7 @@ steps:
 
 # This test is used only in PR development phase to test individual models and should never run on main
 - label: Custom Models Test
+  mirror_hardwares: [amd]
   optional: true
   commands:
   - echo 'Testing custom models...'
@@ -467,6 +490,7 @@ steps:
 ##### multi gpus test #####
 
 - label: Distributed Comm Ops Test # 7min
+  mirror_hardwares: [amd]
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   source_file_dependencies:
@@ -509,10 +533,11 @@ steps:
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
   - entrypoints/llm/test_collective_rpc.py
+  - tests/v1/test_async_llm_dp.py
+  - vllm/v1/engine/
   commands:
+  - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
   - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - VLLM_USE_V1=1 torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
|
||||||
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
|
||||||
- pytest -v -s ./compile/test_basic_correctness.py
|
- pytest -v -s ./compile/test_basic_correctness.py
|
||||||
- pytest -v -s ./compile/test_wrapper.py
|
- pytest -v -s ./compile/test_wrapper.py
|
||||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||||
@ -527,6 +552,7 @@ steps:
|
|||||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||||
|
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||||
|
|
||||||
- label: Plugin Tests (2 GPUs) # 40min
|
- label: Plugin Tests (2 GPUs) # 40min
|
||||||
working_dir: "/vllm-workspace/tests"
|
working_dir: "/vllm-workspace/tests"
|
||||||
@ -589,14 +615,10 @@ steps:
|
|||||||
# FIXIT: find out which code initialize cuda before running the test
|
# FIXIT: find out which code initialize cuda before running the test
|
||||||
# before the fix, we need to use spawn to test it
|
# before the fix, we need to use spawn to test it
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
|
||||||
- pytest -v -s -x lora/test_long_context.py
|
|
||||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||||
# requires multi-GPU testing for validation.
|
# requires multi-GPU testing for validation.
|
||||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||||
- pytest -v -s -x lora/test_llama_tp.py
|
- pytest -v -s -x lora/test_llama_tp.py
|
||||||
- pytest -v -s -x lora/test_minicpmv_tp.py
|
|
||||||
- pytest -v -s -x lora/test_transfomers_model.py
|
|
||||||
|
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test # 33min
|
- label: Weight Loading Multiple GPU Test # 33min
|
||||||
|
|||||||
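For orientation, the model-initialization job above is now split by `-k` expression. A rough local equivalent, run from the repository's `tests/` directory under an existing vLLM dev install (a sketch only; the commands are copied from the pipeline entries above, not an officially documented workflow), would be:

```bash
# Mirror the new CI split for models/test_initialization.py locally.
pip install causal-conv1d   # needed for plamo2; CI installs it separately as well
VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'
VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4'
VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2'
```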
.github/ISSUE_TEMPLATE/200-installation.yml (vendored)
@@ -14,7 +14,7 @@ body:
  description: |
  Please run the following and paste the output below.
  ```sh
- wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+ wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
  # For security purposes, please feel free to check the contents of collect_env.py before running it.
  python collect_env.py
  ```
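The substantive change here (repeated in the other issue templates below) is the new location of `collect_env.py`; the full snippet users are asked to run now reads:

```bash
# Updated environment-collection snippet from the issue templates above.
wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
# For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py
```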
.github/ISSUE_TEMPLATE/300-usage.yml (vendored)
@@ -14,7 +14,7 @@ body:
  description: |
  Please run the following and paste the output below.
  ```sh
- wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+ wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
  # For security purposes, please feel free to check the contents of collect_env.py before running it.
  python collect_env.py
  ```
.github/ISSUE_TEMPLATE/400-bug-report.yml (vendored)
@@ -14,7 +14,7 @@ body:
  description: |
  Please run the following and paste the output below.
  ```sh
- wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+ wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
  # For security purposes, please feel free to check the contents of collect_env.py before running it.
  python collect_env.py
  ```
.github/ISSUE_TEMPLATE/600-new-model.yml (vendored)
@@ -9,7 +9,7 @@ body:
  value: >
  #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).

- #### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
+ #### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/index.html first to understand how to add a new model.
  - type: textarea
  attributes:
  label: The model to consider.
@@ -35,7 +35,7 @@ body:
  description: |
  Please run the following and paste the output below.
  ```sh
- wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+ wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
  # For security purposes, please feel free to check the contents of collect_env.py before running it.
  python collect_env.py
  ```
.github/ISSUE_TEMPLATE/800-misc-discussion.yml (vendored, file deleted)
@@ -1,28 +0,0 @@
- name: 🎲 Misc/random discussions that do not fit into the above categories.
- description: Submit a discussion as you like. Note that developers are heavily overloaded and we mainly rely on community users to answer these issues.
- title: "[Misc]: "
- labels: ["misc"]
-
- body:
- - type: markdown
-   attributes:
-     value: >
-       #### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
- - type: textarea
-   attributes:
-     label: Anything you want to discuss about vllm.
-     description: >
-       Anything you want to discuss about vllm.
-   validations:
-     required: true
- - type: markdown
-   attributes:
-     value: >
-       Thanks for contributing 🎉!
- - type: checkboxes
-   id: askllm
-   attributes:
-     label: Before submitting a new issue...
-     options:
-       - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
-         required: true
.github/ISSUE_TEMPLATE/config.yml (vendored)
@@ -1 +1,5 @@
  blank_issues_enabled: false
+ contact_links:
+   - name: Questions
+     url: https://discuss.vllm.ai
+     about: Ask questions and discuss with other vLLM community members
.github/PULL_REQUEST_TEMPLATE.md (vendored)
@@ -3,4 +3,4 @@ FILL IN THE PR DESCRIPTION HERE
  FIX #xxxx (*link existing issues this PR will resolve*)

  <!--- pyml disable-next-line no-emphasis-as-heading -->
- **BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>**
+ **BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing/overview.html>** (anything written below this line will be removed by GitHub Actions)
.github/mergify.yml (vendored)
@@ -19,7 +19,7 @@ pull_request_rules:
  - files~=\.buildkite/
  - files~=^cmake/
  - files=CMakeLists.txt
- - files~=^Dockerfile
+ - files~=^docker/Dockerfile
  - files~=^requirements.*\.txt
  - files=setup.py
  actions:
@@ -88,6 +88,36 @@ pull_request_rules:
  add:
  - v1

+ - name: label-tpu
+   description: Automatically apply tpu label
+   # Keep this list in sync with `label-tpu-remove` conditions
+   conditions:
+     - or:
+       - files~=tpu.py
+       - files~=_tpu
+       - files~=tpu_
+       - files~=/tpu/
+       - files~=pallas
+   actions:
+     label:
+       add:
+         - tpu
+
+ - name: label-tpu-remove
+   description: Automatically remove tpu label
+   # Keep this list in sync with `label-tpu` conditions
+   conditions:
+     - and:
+       - -files~=tpu.py
+       - -files~=_tpu
+       - -files~=tpu_
+       - -files~=/tpu/
+       - -files~=pallas
+   actions:
+     label:
+       remove:
+         - tpu
+
  - name: ping author on conflicts and add 'needs-rebase' label
  conditions:
  - conflict
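The `files~=` conditions above are regular expressions applied to changed file paths. As a rough, hypothetical local approximation (grep's `-E` matching is close to, but not identical with, mergify's semantics), you could check whether a branch would pick up the `tpu` label like this:

```bash
# Sketch: would the files changed relative to main match the label-tpu patterns above?
git diff --name-only origin/main... \
  | grep -E -e 'tpu\.py' -e '_tpu' -e 'tpu_' -e '/tpu/' -e 'pallas' \
  && echo "would be labeled: tpu" \
  || echo "no tpu label"
```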
.github/workflows/lint-and-deploy.yaml (vendored)
@@ -50,7 +50,7 @@ jobs:
  uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0

  - name: Build the Docker image vllm cpu
- run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env .
+ run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .

  - name: Configuration of docker images, network and namespace for the kind cluster
  run: |
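The same path change applies when building the CPU image by hand; assuming a checkout in which the Dockerfiles have been moved under `docker/`, the local equivalent of the workflow step above is:

```bash
# Build the CPU image from the relocated Dockerfile (run from the repository root).
docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env .
```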
.gitignore (vendored)
@@ -2,7 +2,8 @@
  /vllm/_version.py

  # vllm-flash-attn built from source
- vllm/vllm_flash_attn/
+ vllm/vllm_flash_attn/*
+ !vllm/vllm_flash_attn/fa_utils.py

  # Byte-compiled / optimized / DLL files
  __pycache__/
@@ -202,3 +203,6 @@ benchmarks/**/*.json
  # Linting
  actionlint
  shellcheck*/
+
+ # Ingore moe/marlin_moe gen code
+ csrc/moe/marlin_moe_wna16/kernel_*
@@ -1,3 +1,6 @@
+ default_install_hook_types:
+   - pre-commit
+   - commit-msg
  default_stages:
  - pre-commit # Run locally
  - manual # Run in CI
@@ -8,7 +11,6 @@ repos:
  hooks:
  - id: yapf
  args: [--in-place, --verbose]
- additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
  - repo: https://github.com/astral-sh/ruff-pre-commit
  rev: v0.9.3
  hooks:
@@ -119,6 +121,12 @@ repos:
  language: system
  always_run: true
  pass_filenames: false
+ - id: update-dockerfile-graph
+   name: Update Dockerfile dependency graph
+   entry: tools/update-dockerfile-graph.sh
+   language: script
+   files: ^docker/Dockerfile$
+   pass_filenames: false
  # Keep `suggestion` last
  - id: suggestion
  name: Suggestion
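For contributors, the practical effect of the `default_install_hook_types` addition in the pre-commit configuration above is that a plain install now registers both the pre-commit and the commit-msg hooks. A typical setup, using the standard pre-commit CLI (nothing repo-specific assumed here), looks like:

```bash
# Install the hooks declared in the configuration above, then run them once over the tree.
pip install pre-commit
pre-commit install          # with default_install_hook_types this also registers the commit-msg hook
pre-commit run --all-files  # run all configured hooks against the whole checkout
```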
@@ -34,7 +34,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
  set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")

  # Supported AMD GPU architectures.
- set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
+ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")

  #
  # Supported/expected torch versions for CUDA/ROCm.
@@ -44,7 +44,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101")
  #
  # Note: the CUDA torch version is derived from pyproject.toml and various
  # requirements.txt files and should be kept consistent. The ROCm torch
- # versions are derived from Dockerfile.rocm
+ # versions are derived from docker/Dockerfile.rocm
  #
  set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0")
  set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0")
@@ -230,10 +230,12 @@ set(VLLM_EXT_SRC
  "csrc/cache_kernels.cu"
  "csrc/attention/paged_attention_v1.cu"
  "csrc/attention/paged_attention_v2.cu"
+ "csrc/attention/merge_attn_states.cu"
  "csrc/pos_encoding_kernels.cu"
  "csrc/activation_kernels.cu"
  "csrc/layernorm_kernels.cu"
  "csrc/layernorm_quant_kernels.cu"
+ "csrc/cuda_view.cu"
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
  "csrc/quantization/fp8/common.cu"
@@ -241,6 +243,7 @@ set(VLLM_EXT_SRC
  "csrc/quantization/gguf/gguf_kernel.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/prepare_inputs/advance_step.cu"
+ "csrc/custom_all_reduce.cu"
  "csrc/torch_bindings.cpp")

  if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -282,7 +285,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
  "csrc/quantization/aqlm/gemm_kernels.cu"
  "csrc/quantization/awq/gemm_kernels.cu"
- "csrc/custom_all_reduce.cu"
  "csrc/permute_cols.cu"
  "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
  "csrc/quantization/fp4/nvfp4_quant_entry.cu"
@@ -461,6 +463,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  set(FP4_ARCHS)
  endif()

+ #
+ # CUTLASS MoE kernels
+
+ # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works
+ # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible
+ # to compile MoE kernels that use its output.
+ cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
+   set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
+            "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+   set_gencode_flags_for_srcs(
+     SRCS "${SRCS}"
+     CUDA_ARCHS "${SCALED_MM_ARCHS}")
+   list(APPEND VLLM_EXT_SRC "${SRCS}")
+   list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1")
+   message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
+ else()
+   if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
+     message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
+                    "not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
+                    "if you intend on running FP8 quantized MoE models on Hopper.")
+   else()
+     message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
+                    "in CUDA target architectures")
+   endif()
+ endif()
+
  #
  # Machete kernels

@@ -580,21 +609,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
  cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}")
  if (MARLIN_MOE_ARCHS)
- set(MARLIN_MOE_SRC
-     "csrc/moe/marlin_kernels/marlin_moe_kernel.h"
-     "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
-     "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
-     "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
-     "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
-     "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h"
-     "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu"
-     "csrc/moe/marlin_moe_ops.cu")
+
+   #
+   # For the Marlin MOE kernels we automatically generate sources for various
+   # preselected input type pairs and schedules.
+   # Generate sources:
+   set(MOE_MARLIN_GEN_SCRIPT
+     ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe/marlin_moe_wna16/generate_kernels.py)
+   file(MD5 ${MOE_MARLIN_GEN_SCRIPT} MOE_MARLIN_GEN_SCRIPT_HASH)
+
+   message(STATUS "Marlin MOE generation script hash: ${MOE_MARLIN_GEN_SCRIPT_HASH}")
+   message(STATUS "Last run Marlin MOE generate script hash: $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}")
+
+   if (NOT DEFINED CACHE{MOE_MARLIN_GEN_SCRIPT_HASH}
+       OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH})
+     execute_process(
+       COMMAND ${CMAKE_COMMAND} -E env
+       PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
+       ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT}
+       RESULT_VARIABLE moe_marlin_generation_result
+       OUTPUT_VARIABLE moe_marlin_generation_output
+       OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
+       ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log
+     )
+
+     if (NOT moe_marlin_generation_result EQUAL 0)
+       message(FATAL_ERROR "Marlin MOE generation failed."
+               " Result: \"${moe_marlin_generation_result}\""
+               "\nCheck the log for details: "
+               "${CMAKE_CURRENT_BINARY_DIR}/moe_marlin_generation.log")
+     else()
+       set(MOE_MARLIN_GEN_SCRIPT_HASH ${MOE_MARLIN_GEN_SCRIPT_HASH}
+           CACHE STRING "Last run Marlin MOE generate script hash" FORCE)
+       message(STATUS "Marlin MOE generation completed successfully.")
+     endif()
+   else()
+     message(STATUS "Marlin MOE generation script has not changed, skipping generation.")
+   endif()
+
+   file(GLOB MOE_WNAA16_MARLIN_SRC "csrc/moe/marlin_moe_wna16/*.cu")
  set_gencode_flags_for_srcs(
- SRCS "${MARLIN_MOE_SRC}"
+ SRCS "${MOE_WNAA16_MARLIN_SRC}"
  CUDA_ARCHS "${MARLIN_MOE_ARCHS}")

- list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
+ list(APPEND VLLM_MOE_EXT_SRC ${MOE_WNAA16_MARLIN_SRC})

  message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
  else()
  message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
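The Marlin MoE sources are now produced at configure time by the generator script referenced above. As a rough manual sketch of what that CMake step does (the exact PYTHONPATH entries CMake adds, such as the CUTLASS python/ directory, may be required depending on the script's imports):

```bash
# Sketch only: run the Marlin MoE kernel generator by hand, from the repository root.
# CMake normally invokes it via execute_process with extra PYTHONPATH entries;
# replicate those if the script complains about missing imports.
python3 csrc/moe/marlin_moe_wna16/generate_kernels.py
ls csrc/moe/marlin_moe_wna16/kernel_*   # generated sources (ignored via the new .gitignore entry)
```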
@@ -1,69 +0,0 @@ (whole file removed; the file name is not shown in this view)
- # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
-
- FROM ubuntu:22.04 AS cpu-test-1
-
- ENV CCACHE_DIR=/root/.cache/ccache
-
- ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
-
- RUN --mount=type=cache,target=/var/cache/apt \
-     apt-get update -y \
-     && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
-     && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
-     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
-
- # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
- # intel-openmp provides additional performance improvement vs. openmp
- # tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects.
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install intel-openmp==2025.0.1
-
- ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so"
-
- RUN echo 'ulimit -c 0' >> ~/.bashrc
-
- RUN pip install intel_extension_for_pytorch==2.6.0
-
- WORKDIR /workspace
-
- ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
- ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
- RUN --mount=type=cache,target=/root/.cache/pip \
-     --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
-     pip install --upgrade pip && \
-     pip install -r requirements/build.txt
-
- FROM cpu-test-1 AS build
-
- WORKDIR /workspace/vllm
-
- RUN --mount=type=cache,target=/root/.cache/pip \
-     --mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \
-     --mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \
-     pip install -v -r requirements/cpu.txt
-
- COPY . .
- ARG GIT_REPO_CHECK=0
- RUN --mount=type=bind,source=.git,target=.git \
-     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
-
- # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
- ARG VLLM_CPU_DISABLE_AVX512
- ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
-
- RUN --mount=type=cache,target=/root/.cache/pip \
-     --mount=type=cache,target=/root/.cache/ccache \
-     --mount=type=bind,source=.git,target=.git \
-     VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
-     pip install dist/*.whl && \
-     rm -rf dist
-
- WORKDIR /workspace/
-
- RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
-
- # install development dependencies (for testing)
- RUN --mount=type=cache,target=/root/.cache/pip \
-     pip install -e tests/vllm_test_utils
-
- ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
@@ -1,29 +0,0 @@ (whole file removed; the file name is not shown in this view)
- # The vLLM Dockerfile is used to construct vLLM image that can be directly used
- # to run the OpenAI compatible server.
-
- FROM ubuntu:22.04 AS dev
-
- RUN apt-get update -y && \
-     apt-get install -y \
-     git python3-pip \
-     ffmpeg libsm6 libxext6 libgl1
- WORKDIR /workspace
-
- COPY . .
- ARG GIT_REPO_CHECK=0
- RUN --mount=type=bind,source=.git,target=.git \
-     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
-
- RUN python3 -m pip install -U pip
- # install build requirements
- RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements/build.txt
- # build vLLM with OpenVINO backend
- RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace
-
- COPY examples/ /workspace/examples
- COPY benchmarks/ /workspace/benchmarks
-
- # install development dependencies (for testing)
- RUN python3 -m pip install -e tests/vllm_test_utils
-
- CMD ["/bin/bash"]
@@ -1,37 +0,0 @@ (whole file removed; the file name is not shown in this view)
- FROM mambaorg/micromamba
- ARG MAMBA_DOCKERFILE_ACTIVATE=1
- USER root
-
- ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
-
- RUN apt-get update -y && apt-get install -y git wget kmod curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
-
- # Some packages in requirements/cpu are installed here
- # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
- # Currently these may not be available for venv or pip directly
- RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 rust && micromamba clean --all --yes
-
- COPY ./ /workspace/vllm
-
- WORKDIR /workspace/vllm
- ARG GIT_REPO_CHECK=0
- RUN --mount=type=bind,source=.git,target=.git \
-     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
-
- RUN --mount=type=cache,target=/root/.cache/pip \
-     RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
-     'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-     -r requirements/cpu.txt \
-     xformers uvloop==0.20.0
-
- RUN --mount=type=bind,source=.git,target=.git \
-     VLLM_TARGET_DEVICE=cpu python3 setup.py install
-
- # install development dependencies (for testing)
- RUN python3 -m pip install -e tests/vllm_test_utils
-
- WORKDIR /workspace/
-
- RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
-
- ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]
README.md
@@ -10,17 +10,24 @@ Easy, fast, and cheap LLM serving for everyone
  </h3>

  <p align="center">
- | <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
+ | <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
  </p>

- *Latest News* 🔥
+ ---

- - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit#slide=id.g33fb1ff286e_0_29).
+ *Latest News* 🔥
+ - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
+ - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
+ - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
  - [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
  - [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
  - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
  - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
  - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!

+ <details>
+ <summary>Previous News</summary>
+
  - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
  - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
  - [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
@@ -34,8 +41,9 @@ Easy, fast, and cheap LLM serving for everyone
  - [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
  - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

- ---
+ </details>

+ ---
  ## About

  vLLM is a fast and easy-to-use library for LLM inference and serving.
@@ -90,7 +98,7 @@ Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
  ## Contributing

  We welcome and value any contributions and collaborations.
- Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
+ Please check out [Contributing to vLLM](https://docs.vllm.ai/en/stable/contributing/overview.html) for how to get involved.

  ## Sponsors

@@ -113,6 +121,7 @@ Compute Resources:
  - Databricks
  - DeepInfra
  - Google Cloud
+ - Intel
  - Lambda Lab
  - Nebius
  - Novita AI
@@ -143,10 +152,11 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs

  ## Contact Us

- - For technical questions and feature requests, please use GitHub issues or discussions.
- - For discussing with fellow users and coordinating contributions and development, please use Slack.
- - For security disclosures, please use GitHub's security advisory feature.
- - For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu.
+ - For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
+ - For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
+ - coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
+ - For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
+ - For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)

  ## Media Kit
@@ -41,29 +41,39 @@ become available.
  <td><code>synthetic</code></td>
  </tr>
  <tr>
- <td><strong>HuggingFace</strong></td>
+ <td><strong>HuggingFace-VisionArena</strong></td>
  <td style="text-align: center;">✅</td>
- <td style="text-align: center;">🟡</td>
- <td>Specify your dataset path on HuggingFace</td>
+ <td style="text-align: center;">✅</td>
+ <td><code>lmarena-ai/VisionArena-Chat</code></td>
  </tr>
  <tr>
- <td><strong>VisionArena</strong></td>
+ <td><strong>HuggingFace-InstructCoder</strong></td>
  <td style="text-align: center;">✅</td>
  <td style="text-align: center;">✅</td>
- <td><code>lmarena-ai/vision-arena-bench-v0.1</code> (a HuggingFace dataset)</td>
+ <td><code>likaixin/InstructCoder</code></td>
+ </tr>
+ <tr>
+ <td><strong>HuggingFace-AIMO</strong></td>
+ <td style="text-align: center;">✅</td>
+ <td style="text-align: center;">✅</td>
+ <td><code>AI-MO/aimo-validation-aime</code> , <code>AI-MO/NuminaMath-1.5</code>, <code>AI-MO/NuminaMath-CoT</code></td>
+ </tr>
+ <tr>
+ <td><strong>HuggingFace-Other</strong></td>
+ <td style="text-align: center;">✅</td>
+ <td style="text-align: center;">✅</td>
+ <td><code>lmms-lab/LLaVA-OneVision-Data</code>, <code>Aeala/ShareGPT_Vicuna_unfiltered</code></td>
  </tr>
  </tbody>
  </table>

  ✅: supported

+ 🟡: Partial support
+
  🚧: to be supported

- 🟡: Partial support. Currently, HuggingFaceDataset only supports dataset formats
- similar to `lmms-lab/LLaVA-OneVision-Data`. If you need support for other dataset
- formats, please consider contributing.
-
- **Note**: VisionArena’s `dataset-name` should be set to `hf`
+ **Note**: HuggingFace dataset's `dataset-name` should be set to `hf`

  ---
  ## Example - Online Benchmark
@@ -71,8 +81,7 @@ formats, please consider contributing.
  First start serving your model

  ```bash
- MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
- vllm serve ${MODEL_NAME} --disable-log-requests
+ vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests
  ```

  Then run the benchmarking script
@@ -80,12 +89,13 @@ Then run the benchmarking script
  ```bash
  # download dataset
  # wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
- MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
- NUM_PROMPTS=10
- BACKEND="vllm"
- DATASET_NAME="sharegpt"
- DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
- python3 vllm/benchmarks/benchmark_serving.py --backend ${BACKEND} --model ${MODEL_NAME} --endpoint /v1/completions --dataset-name ${DATASET_NAME} --dataset-path ${DATASET_PATH} --num-prompts ${NUM_PROMPTS}
+ python3 vllm/benchmarks/benchmark_serving.py \
+     --backend vllm \
+     --model NousResearch/Hermes-3-Llama-3.1-8B \
+     --endpoint /v1/completions \
+     --dataset-name sharegpt \
+     --dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
+     --num-prompts 10
  ```

  If successful, you will see the following output
@@ -122,37 +132,105 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
  ```

  ```bash
- MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
- NUM_PROMPTS=10
- BACKEND="openai-chat"
- DATASET_NAME="hf"
- DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
- DATASET_SPLIT='train'
-
  python3 vllm/benchmarks/benchmark_serving.py \
-     --backend "${BACKEND}" \
-     --model "${MODEL_NAME}" \
-     --endpoint "/v1/chat/completions" \
-     --dataset-name "${DATASET_NAME}" \
-     --dataset-path "${DATASET_PATH}" \
-     --hf-split "${DATASET_SPLIT}" \
-     --num-prompts "${NUM_PROMPTS}"
+     --backend openai-chat \
+     --model Qwen/Qwen2-VL-7B-Instruct \
+     --endpoint /v1/chat/completions \
+     --dataset-name hf \
+     --dataset-path lmarena-ai/VisionArena-Chat \
+     --hf-split train \
+     --num-prompts 1000
+ ```
+
+ ### InstructCoder Benchmark with Speculative Decoding
+
+ ``` bash
+ VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
+     --speculative-model "[ngram]" \
+     --ngram_prompt_lookup_min 2 \
+     --ngram-prompt-lookup-max 5 \
+     --num_speculative_tokens 5
+ ```
+
+ ``` bash
+ python3 benchmarks/benchmark_serving.py \
+     --model meta-llama/Meta-Llama-3-8B-Instruct \
+     --dataset-name hf \
+     --dataset-path likaixin/InstructCoder \
+     --num-prompts 2048
+ ```
+
+ ### Other HuggingFaceDataset Examples
+
+ ```bash
+ vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
+ ```
+
+ **`lmms-lab/LLaVA-OneVision-Data`**
+
+ ```bash
+ python3 vllm/benchmarks/benchmark_serving.py \
+     --backend openai-chat \
+     --model Qwen/Qwen2-VL-7B-Instruct \
+     --endpoint /v1/chat/completions \
+     --dataset-name hf \
+     --dataset-path lmms-lab/LLaVA-OneVision-Data \
+     --hf-split train \
+     --hf-subset "chart2text(cauldron)" \
+     --num-prompts 10
+ ```
+
+ **`Aeala/ShareGPT_Vicuna_unfiltered`**
+
+ ```bash
+ python3 vllm/benchmarks/benchmark_serving.py \
+     --backend openai-chat \
+     --model Qwen/Qwen2-VL-7B-Instruct \
+     --endpoint /v1/chat/completions \
+     --dataset-name hf \
+     --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
+     --hf-split train \
+     --num-prompts 10
+ ```
+
+ **`AI-MO/aimo-validation-aime`**
+
+ ``` bash
+ python3 vllm/benchmarks/benchmark_serving.py \
+     --model Qwen/QwQ-32B \
+     --dataset-name hf \
+     --dataset-path AI-MO/aimo-validation-aime \
+     --num-prompts 10 \
+     --seed 42
+ ```
+
+ ### Running With Sampling Parameters
+
+ When using OpenAI-compatible backends such as `vllm`, optional sampling
+ parameters can be specified. Example client command:
+
+ ```bash
+ python3 vllm/benchmarks/benchmark_serving.py \
+     --backend vllm \
+     --model NousResearch/Hermes-3-Llama-3.1-8B \
+     --endpoint /v1/completions \
+     --dataset-name sharegpt \
+     --dataset-path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
+     --top-k 10 \
+     --top-p 0.9 \
+     --temperature 0.5 \
+     --num-prompts 10
  ```

  ---
  ## Example - Offline Throughput Benchmark

  ```bash
- MODEL_NAME="NousResearch/Hermes-3-Llama-3.1-8B"
- NUM_PROMPTS=10
- DATASET_NAME="sonnet"
- DATASET_PATH="vllm/benchmarks/sonnet.txt"
-
  python3 vllm/benchmarks/benchmark_throughput.py \
-     --model "${MODEL_NAME}" \
-     --dataset-name "${DATASET_NAME}" \
-     --dataset-path "${DATASET_PATH}" \
-     --num-prompts "${NUM_PROMPTS}"
+     --model NousResearch/Hermes-3-Llama-3.1-8B \
+     --dataset-name sonnet \
+     --dataset-path vllm/benchmarks/sonnet.txt \
+     --num-prompts 10
  ```

  If successful, you will see the following output
@@ -166,19 +244,13 @@ Total num output tokens: 1500
  ### VisionArena Benchmark for Vision Language Models

  ``` bash
- MODEL_NAME="Qwen/Qwen2-VL-7B-Instruct"
- NUM_PROMPTS=10
- DATASET_NAME="hf"
- DATASET_PATH="lmarena-ai/vision-arena-bench-v0.1"
- DATASET_SPLIT="train"
-
  python3 vllm/benchmarks/benchmark_throughput.py \
-     --model "${MODEL_NAME}" \
-     --backend "vllm-chat" \
-     --dataset-name "${DATASET_NAME}" \
-     --dataset-path "${DATASET_PATH}" \
-     --num-prompts "${NUM_PROMPTS}" \
-     --hf-split "${DATASET_SPLIT}"
+     --model Qwen/Qwen2-VL-7B-Instruct \
+     --backend vllm-chat \
+     --dataset-name hf \
+     --dataset-path lmarena-ai/VisionArena-Chat \
+     --num-prompts 1000 \
+     --hf-split train
  ```

  The `num prompt tokens` now includes image token counts
@@ -189,29 +261,83 @@ Total num prompt tokens: 14527
  Total num output tokens: 1280
  ```

+ ### InstructCoder Benchmark with Speculative Decoding
+
+ ``` bash
+ VLLM_WORKER_MULTIPROC_METHOD=spawn \
+ VLLM_USE_V1=1 \
+ python3 vllm/benchmarks/benchmark_throughput.py \
+     --dataset-name=hf \
+     --dataset-path=likaixin/InstructCoder \
+     --model=meta-llama/Meta-Llama-3-8B-Instruct \
+     --input-len=1000 \
+     --output-len=100 \
+     --num-prompts=2048 \
+     --async-engine \
+     --speculative-model="[ngram]" \
+     --ngram_prompt_lookup_min=2 \
+     --ngram-prompt-lookup-max=5 \
+     --num_speculative_tokens=5
+ ```
+
+ ```
+ Throughput: 104.77 requests/s, 23836.22 total tokens/s, 10477.10 output tokens/s
+ Total num prompt tokens: 261136
+ Total num output tokens: 204800
+ ```
+
+ ### Other HuggingFaceDataset Examples
+
+ **`lmms-lab/LLaVA-OneVision-Data`**
+
+ ```bash
+ python3 vllm/benchmarks/benchmark_throughput.py \
+     --model Qwen/Qwen2-VL-7B-Instruct \
+     --backend vllm-chat \
+     --dataset-name hf \
+     --dataset-path lmms-lab/LLaVA-OneVision-Data \
+     --hf-split train \
+     --hf-subset "chart2text(cauldron)" \
+     --num-prompts 10
+ ```
+
+ **`Aeala/ShareGPT_Vicuna_unfiltered`**
+
+ ```bash
+ python3 vllm/benchmarks/benchmark_throughput.py \
+     --model Qwen/Qwen2-VL-7B-Instruct \
+     --backend vllm-chat \
+     --dataset-name hf \
+     --dataset-path Aeala/ShareGPT_Vicuna_unfiltered \
+     --hf-split train \
+     --num-prompts 10
+ ```
+
+ **`AI-MO/aimo-validation-aime`**
+
+ ```bash
+ python3 benchmarks/benchmark_throughput.py \
+     --model Qwen/QwQ-32B \
+     --backend vllm \
+     --dataset-name hf \
+     --dataset-path AI-MO/aimo-validation-aime \
+     --hf-split train \
+     --num-prompts 10
+ ```
+
  ### Benchmark with LoRA Adapters

  ``` bash
  # download dataset
  # wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
- MODEL_NAME="meta-llama/Llama-2-7b-hf"
- BACKEND="vllm"
- DATASET_NAME="sharegpt"
- DATASET_PATH="<your data path>/ShareGPT_V3_unfiltered_cleaned_split.json"
- NUM_PROMPTS=10
- MAX_LORAS=2
- MAX_LORA_RANK=8
- ENABLE_LORA="--enable-lora"
- LORA_PATH="yard1/llama-2-7b-sql-lora-test"
-
  python3 vllm/benchmarks/benchmark_throughput.py \
-     --model "${MODEL_NAME}" \
-     --backend "${BACKEND}" \
-     --dataset_path "${DATASET_PATH}" \
-     --dataset_name "${DATASET_NAME}" \
-     --num-prompts "${NUM_PROMPTS}" \
-     --max-loras "${MAX_LORAS}" \
-     --max-lora-rank "${MAX_LORA_RANK}" \
-     ${ENABLE_LORA} \
-     --lora-path "${LORA_PATH}"
+     --model meta-llama/Llama-2-7b-hf \
+     --backend vllm \
+     --dataset_path <your data path>/ShareGPT_V3_unfiltered_cleaned_split.json \
+     --dataset_name sharegpt \
+     --num-prompts 10 \
+     --max-loras 2 \
+     --max-lora-rank 8 \
+     --enable-lora \
+     --lora-path yard1/llama-2-7b-sql-lora-test
  ```
@ -63,7 +63,7 @@ async def async_request_tgi(
|
|||||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||||
"truncate": request_func_input.prompt_len,
|
"truncate": request_func_input.prompt_len,
|
||||||
# TGI does not accept ignore_eos flag.
|
"ignore_eos_token": request_func_input.ignore_eos,
|
||||||
}
|
}
|
||||||
payload = {
|
payload = {
|
||||||
"inputs": request_func_input.prompt,
|
"inputs": request_func_input.prompt,
|
||||||
@ -71,6 +71,10 @@ async def async_request_tgi(
|
|||||||
}
|
}
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
if request_func_input.ignore_eos:
|
||||||
|
output.output_tokens = request_func_input.output_len
|
||||||
|
else:
|
||||||
|
output.output_tokens = None
|
||||||
|
|
||||||
ttft = 0.0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
@ -215,7 +219,15 @@ async def async_request_deepspeed_mii(
|
|||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
parsed_resp = await response.json()
|
parsed_resp = await response.json()
|
||||||
output.latency = time.perf_counter() - st
|
output.latency = time.perf_counter() - st
|
||||||
output.generated_text = parsed_resp["text"][0]
|
if "choices" in parsed_resp:
|
||||||
|
output.generated_text = parsed_resp["choices"][0][
|
||||||
|
"text"]
|
||||||
|
elif "text" in parsed_resp:
|
||||||
|
output.generated_text = parsed_resp["text"][0]
|
||||||
|
else:
|
||||||
|
output.error = ("Unexpected response format: "
|
||||||
|
"neither 'choices' nor 'text' found")
|
||||||
|
output.success = False
|
||||||
output.success = True
|
output.success = True
|
||||||
else:
|
else:
|
||||||
output.error = response.reason or ""
|
output.error = response.reason or ""
|
||||||
@ -485,3 +497,9 @@ ASYNC_REQUEST_FUNCS = {
|
|||||||
"scalellm": async_request_openai_completions,
|
"scalellm": async_request_openai_completions,
|
||||||
"sglang": async_request_openai_completions,
|
"sglang": async_request_openai_completions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OPENAI_COMPATIBLE_BACKENDS = [
|
||||||
|
k for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||||
|
if v in (async_request_openai_completions,
|
||||||
|
async_request_openai_chat_completions)
|
||||||
|
]
|
||||||
|
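The last hunk above derives `OPENAI_COMPATIBLE_BACKENDS` from the existing `ASYNC_REQUEST_FUNCS` registry instead of maintaining a second hand-written list. A small self-contained sketch of the same pattern, with stand-in handler functions since the full module is not reproduced here:

```python
# Sketch of deriving the OpenAI-compatible backend list from a registry,
# mirroring the hunk above; the handler functions are stand-ins.
def openai_completions(req): ...
def openai_chat_completions(req): ...
def tgi(req): ...


ASYNC_REQUEST_FUNCS = {
    "vllm": openai_completions,
    "sglang": openai_completions,
    "openai-chat": openai_chat_completions,
    "tgi": tgi,
}

OPENAI_COMPATIBLE_BACKENDS = [
    k for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (openai_completions, openai_chat_completions)
]

print(OPENAI_COMPATIBLE_BACKENDS)  # ['vllm', 'sglang', 'openai-chat']
```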
```diff
@@ -17,12 +17,14 @@ SampleRequest instances, similar to the approach used in ShareGPT.
 import base64
 import io
 import json
+import logging
 import random
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass
 from functools import cache
-from typing import Any, Optional, Union
+from io import BytesIO
+from typing import Any, Callable, Optional, Union

 import numpy as np
 import pandas as pd
@@ -35,6 +37,8 @@ from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer

+logger = logging.getLogger(__name__)
+
 # -----------------------------------------------------------------------------
 # Data Classes
 # -----------------------------------------------------------------------------
@@ -61,9 +65,6 @@ class SampleRequest:
 class BenchmarkDataset(ABC):
     DEFAULT_SEED = 0

-    # num_requests has default 1000 in both the benchmark_serving.py and
-    # benchmark_throughput.py
-
     def __init__(
         self,
         dataset_path: Optional[str] = None,
@@ -90,8 +91,8 @@ class BenchmarkDataset(ABC):
             mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
         """
         Transform a prompt and optional multimodal content into a chat format.
-        This method is used for chat models that expect a specific
-        conversation format.
+        This method is used for chat models that expect a specific conversation
+        format.
         """
         content = [{"text": prompt, "type": "text"}]
         if mm_content is not None:
@@ -101,10 +102,10 @@ class BenchmarkDataset(ABC):
     def load_data(self) -> None:
         """
         Load data from the dataset path into self.data.

         This method must be overridden by subclasses since the method to load
         data will vary depending on the dataset format and source.

         Raises:
             NotImplementedError: If a subclass does not implement this method.
         """
@@ -121,18 +122,18 @@ class BenchmarkDataset(ABC):
         """
         Optionally select a random LoRA request and return its associated
         tokenizer.

         This method is used when LoRA parameters are provided. It randomly
         selects a LoRA based on max_loras and retrieves a cached tokenizer for
         that LoRA if available. Otherwise, it returns the base tokenizer.

         Args:
             tokenizer (PreTrainedTokenizerBase): The base tokenizer to use if no
             LoRA is selected. max_loras (Optional[int]): The maximum number of
             LoRAs available. If None, LoRA is not used. lora_path
             (Optional[str]): Path to the LoRA parameters on disk. If None, LoRA
             is not used.

         Returns:
             tuple[Optional[LoRARequest], AnyTokenizer]: A tuple where the first
             element is a LoRARequest (or None if not applicable) and the second
@@ -160,21 +161,39 @@ class BenchmarkDataset(ABC):
                num_requests: int) -> list[SampleRequest]:
         """
         Abstract method to generate sample requests from the dataset.

         Subclasses must override this method to implement dataset-specific logic
         for generating a list of SampleRequest objects.

         Args:
             tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
                 for processing the dataset's text.
             num_requests (int): The number of sample requests to generate.

         Returns:
             list[SampleRequest]: A list of sample requests generated from the
                 dataset.
         """
         raise NotImplementedError("sample must be implemented in subclasses.")

+    def maybe_oversample_requests(self, requests: list[SampleRequest],
+                                  num_requests: int) -> None:
+        """
+        Oversamples the list of requests if its size is less than the desired
+        number.
+
+        Args:
+            requests (List[SampleRequest]): The current list of sampled
+            requests. num_requests (int): The target number of requests.
+        """
+        if len(requests) < num_requests:
+            random.seed(self.random_seed)
+            additional = random.choices(requests,
+                                        k=num_requests - len(requests))
+            requests.extend(additional)
+            logger.info("Oversampled requests to reach %d total samples.",
+                        num_requests)
+
+
 # -----------------------------------------------------------------------------
 # Utility Functions and Global Caches
@@ -221,21 +240,24 @@ def process_image(image: Any) -> Mapping[str, Any]:
     """
     Process a single image input and return a multimedia content dictionary.

-    For a PIL.Image.Image input:
-    - Converts the image to RGB.
-    - Saves the image as a JPEG in-memory.
-    - Encodes the JPEG data as a base64 string.
-    - Returns a dictionary with the image as a base64 data URL.
+    Supports three input types:
+
+    1. Dictionary with raw image bytes: - Expects a dict with a 'bytes' key
+       containing raw image data. - Loads the bytes as a PIL.Image.Image.

-    For a string input:
-    - Treats the string as a URL or file path.
-    - Prepends "file://" if the string doesn't start with "http://" or
-      "file://".
-    - Returns a dictionary with the image URL.
+    2. PIL.Image.Image input: - Converts the image to RGB. - Saves the image as
+       a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
+       a dictionary with the image as a base64 data URL.
+
+    3. String input: - Treats the string as a URL or local file path. -
+       Prepends "file://" if the string doesn't start with "http://" or
+       "file://". - Returns a dictionary with the image URL.

     Raises:
-        ValueError: If the input is neither a PIL.Image.Image nor a string.
+        ValueError: If the input is not a supported type.
     """
+    if isinstance(image, dict) and 'bytes' in image:
+        image = Image.open(BytesIO(image['bytes']))
     if isinstance(image, Image.Image):
         image = image.convert("RGB")
         with io.BytesIO() as image_data:
@@ -254,8 +276,8 @@ def process_image(image: Any) -> Mapping[str, Any]:
                      ("http://", "file://")) else f"file://{image}")
         return {"type": "image_url", "image_url": {"url": image_url}}

-    raise ValueError(
-        f"Invalid image input {image}. Must be a PIL.Image.Image or str.")
+    raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image"
+                     " or str or dictionary with raw image bytes.")


 # -----------------------------------------------------------------------------
@@ -266,7 +288,7 @@ def process_image(image: Any) -> Mapping[str, Any]:
 class RandomDataset(BenchmarkDataset):
     # Default values copied from benchmark_serving.py for the random dataset.
     DEFAULT_PREFIX_LEN = 0
-    DEFAULT_RANGE_RATIO = 1.0
+    DEFAULT_RANGE_RATIO = 0.0
     DEFAULT_INPUT_LEN = 1024
     DEFAULT_OUTPUT_LEN = 128

@@ -276,28 +298,42 @@ class RandomDataset(BenchmarkDataset):
     ) -> None:
         super().__init__(**kwargs)

-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               prefix_len: int = DEFAULT_PREFIX_LEN,
-               range_ratio: float = DEFAULT_RANGE_RATIO,
-               input_len: int = DEFAULT_INPUT_LEN,
-               output_len: int = DEFAULT_OUTPUT_LEN,
-               **kwargs) -> list[SampleRequest]:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        prefix_len: int = DEFAULT_PREFIX_LEN,
+        range_ratio: float = DEFAULT_RANGE_RATIO,
+        input_len: int = DEFAULT_INPUT_LEN,
+        output_len: int = DEFAULT_OUTPUT_LEN,
+        **kwargs,
+    ) -> list[SampleRequest]:
+        # Enforce range_ratio < 1
+        assert range_ratio < 1.0, (
+            "random_range_ratio must be < 1.0 to ensure a valid sampling range"
+        )
+
         vocab_size = tokenizer.vocab_size

         prefix_token_ids = (np.random.randint(
             0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])

-        input_low = int(input_len * range_ratio)
-        output_low = int(output_len * range_ratio)
+        # New sampling logic: [X * (1 - b), X * (1 + b)]
+        input_low = int(input_len * (1 - range_ratio))
+        input_high = int(input_len * (1 + range_ratio))
+        output_low = int(output_len * (1 - range_ratio))
+        output_high = int(output_len * (1 + range_ratio))
+
+        # Add logging for debugging
+        logger.info("Sampling input_len from [%s, %s]", input_low, input_high)
+        logger.info("Sampling output_len from [%s, %s]", output_low,
+                    output_high)

         input_lens = np.random.randint(input_low,
-                                       input_len + 1,
+                                       input_high + 1,
                                        size=num_requests)
         output_lens = np.random.randint(output_low,
-                                        output_len + 1,
+                                        output_high + 1,
                                         size=num_requests)
         offsets = np.random.randint(0, vocab_size, size=num_requests)

@@ -346,20 +382,24 @@ class ShareGPTDataset(BenchmarkDataset):
         random.seed(self.random_seed)
         random.shuffle(self.data)

-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               lora_path: Optional[str] = None,
-               max_loras: Optional[int] = None,
-               output_len: Optional[int] = None,
-               enable_multimodal_chat: bool = False,
-               **kwargs) -> list:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        lora_path: Optional[str] = None,
+        max_loras: Optional[int] = None,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
+        **kwargs,
+    ) -> list:
         samples: list = []
         for entry in self.data:
             if len(samples) >= num_requests:
                 break
-            prompt, completion = entry["conversations"][0]["value"],\
-                entry["conversations"][1]["value"]
+            prompt, completion = (
+                entry["conversations"][0]["value"],
+                entry["conversations"][1]["value"],
+            )

             lora_request, tokenizer = self.get_random_lora_request(
                 tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path)
@@ -383,6 +423,7 @@ class ShareGPTDataset(BenchmarkDataset):
                     expected_output_len=new_output_len,
                     lora_request=lora_request,
                 ))
+        self.maybe_oversample_requests(samples, num_requests)
         return samples


@@ -415,19 +456,20 @@ class SonnetDataset(BenchmarkDataset):
         with open(self.dataset_path, encoding="utf-8") as f:
             self.data = f.readlines()

-    def sample(self,
-               tokenizer,
-               num_requests: int,
-               prefix_len: int = DEFAULT_PREFIX_LEN,
-               input_len: int = DEFAULT_INPUT_LEN,
-               output_len: int = DEFAULT_OUTPUT_LEN,
-               return_prompt_formatted: bool = False,
-               **kwargs) -> list:
+    def sample(
+        self,
+        tokenizer,
+        num_requests: int,
+        prefix_len: int = DEFAULT_PREFIX_LEN,
+        input_len: int = DEFAULT_INPUT_LEN,
+        output_len: int = DEFAULT_OUTPUT_LEN,
+        return_prompt_formatted: bool = False,
+        **kwargs,
+    ) -> list:
         # Calculate average token length for a poem line.
         tokenized_lines = [tokenizer(line).input_ids for line in self.data]
         avg_len = sum(len(tokens)
-                      for tokens in \
-                      tokenized_lines) / len(tokenized_lines)
+                      for tokens in tokenized_lines) / len(tokenized_lines)

         # Build the base prompt.
         base_prompt = "Pick as many lines as you can from these poem lines:\n"
@@ -443,11 +485,11 @@ class SonnetDataset(BenchmarkDataset):

         # Determine how many poem lines to use.
         num_input_lines = round((input_len - base_offset) / avg_len)
-        num_prefix_lines = round((prefix_len - base_offset) / avg_len)
+        num_prefix_lines = max(round((prefix_len - base_offset) / avg_len), 0)
         prefix_lines = self.data[:num_prefix_lines]

         samples = []
-        for _ in range(num_requests):
+        while len(samples) < num_requests:
             extra_lines = random.choices(self.data,
                                          k=num_input_lines - num_prefix_lines)
             prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}"
@@ -455,13 +497,14 @@ class SonnetDataset(BenchmarkDataset):
             prompt_formatted = tokenizer.apply_chat_template(
                 msg, add_generation_prompt=True, tokenize=False)
             prompt_len = len(tokenizer(prompt_formatted).input_ids)
-            samples.append(
-                SampleRequest(
-                    prompt=prompt_formatted
-                    if return_prompt_formatted else prompt,
-                    prompt_len=prompt_len,
-                    expected_output_len=output_len,
-                ))
+            if prompt_len <= input_len:
+                samples.append(
+                    SampleRequest(
+                        prompt=prompt_formatted
+                        if return_prompt_formatted else prompt,
+                        prompt_len=prompt_len,
+                        expected_output_len=output_len,
+                    ))
         return samples


@@ -506,12 +549,14 @@ class BurstGPTDataset(BenchmarkDataset):
         # Convert the dataframe to a list of lists.
         return data.values.tolist()

-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               max_loras: Optional[int] = None,
-               lora_path: Optional[str] = None,
-               **kwargs) -> list[SampleRequest]:
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        max_loras: Optional[int] = None,
+        lora_path: Optional[str] = None,
+        **kwargs,
+    ) -> list[SampleRequest]:
         samples = []
         data = self._sample_loaded_data(num_requests=num_requests)
         for i in range(num_requests):
@@ -535,49 +580,47 @@ class BurstGPTDataset(BenchmarkDataset):


 # -----------------------------------------------------------------------------
-# HuggingFace Dataset Implementation
+# HuggingFace Dataset Base Implementation
 # -----------------------------------------------------------------------------


 class HuggingFaceDataset(BenchmarkDataset):
-    """
-    Dataset class for processing a HuggingFace dataset with conversation data
-    and optional images.
-    """
-    DEFAULT_NUM_REQUESTS = 1000
+    """Base class for datasets hosted on HuggingFace."""
+
+    SUPPORTED_DATASET_PATHS: Union[set[str], dict[str, Callable]] = set()

     def __init__(
         self,
+        dataset_path: str,
         dataset_split: str,
         dataset_subset: Optional[str] = None,
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
+        super().__init__(dataset_path=dataset_path, **kwargs)

         self.dataset_split = dataset_split
         self.dataset_subset = dataset_subset

         self.load_data()

     def load_data(self) -> None:
-        if not self.dataset_path:
-            raise ValueError("dataset_path must be provided for loading data.")
+        """Load data from HuggingFace datasets."""

         self.data = load_dataset(
             self.dataset_path,
             name=self.dataset_subset,
             split=self.dataset_split,
             streaming=True,
         )
-        if self.data.features is None or "conversations" \
-                not in self.data.features:
-            raise ValueError(
-                "HuggingFaceDataset currently only supports datasets with "
-                "a 'conversations' column like lmms-lab/LLaVA-OneVision-Data. "
-                "Please consider contributing if you would like to add "
-                "support for additional dataset formats.")
-        # Shuffle and filter examples with at least 2 conversations.
-        self.data = self.data.shuffle(seed=self.random_seed).filter(
-            lambda x: len(x["conversations"]) >= 2)
+        self.data = self.data.shuffle(seed=self.random_seed)
+
+
+# -----------------------------------------------------------------------------
+# Conversation Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class ConversationDataset(HuggingFaceDataset):
+    """Dataset for conversation data with multimodal support."""
+    SUPPORTED_DATASET_PATHS = {
+        'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered'
+    }

     def sample(self,
                tokenizer: PreTrainedTokenizerBase,
@@ -585,10 +628,13 @@ class HuggingFaceDataset(BenchmarkDataset):
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
                **kwargs) -> list:
+        # Filter examples with at least 2 conversations
+        filtered_data = self.data.filter(
+            lambda x: len(x["conversations"]) >= 2)
         sampled_requests = []
         dynamic_output = output_len is None

-        for item in self.data:
+        for item in filtered_data:
             if len(sampled_requests) >= num_requests:
                 break
             conv = item["conversations"]
@@ -618,6 +664,7 @@ class HuggingFaceDataset(BenchmarkDataset):
                     expected_output_len=output_len,
                     multi_modal_data=mm_content,
                 ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests


@@ -632,44 +679,32 @@ class VisionArenaDataset(HuggingFaceDataset):
     """

     DEFAULT_OUTPUT_LEN = 128
-    DEFAULT_NUM_REQUESTS = 1000
-    VISION_ARENA_DATASET_PATH = "lmarena-ai/vision-arena-bench-v0.1"
+    SUPPORTED_DATASET_PATHS = {
+        "lmarena-ai/VisionArena-Chat":
+        lambda x: x["conversation"][0][0]["content"],
+        "lmarena-ai/vision-arena-bench-v0.1":
+        lambda x: x["turns"][0][0]["content"]
+    }

-    def __init__(
+    def sample(
         self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: Optional[int] = None,
+        enable_multimodal_chat: bool = False,
         **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        if self.dataset_path != self.VISION_ARENA_DATASET_PATH:
-            raise ValueError(f"Only support Vision Arena dataset.\
-                    This data path {self.dataset_path} is not valid.")
-        if self.dataset_subset is None and self.dataset_split != "train":
-            raise ValueError("Dataset split must be 'train'.")
-
-        self.load_data()
-
-    def load_data(self) -> None:
-        dataset = load_dataset(
-            self.dataset_path,
-            name=self.dataset_subset,
-            split=self.dataset_split,
-            streaming=True,
-        )
-        self.data = dataset.shuffle(seed=self.random_seed)
-
-    def sample(self,
-               tokenizer: PreTrainedTokenizerBase,
-               num_requests: int,
-               output_len: Optional[int] = None,
-               enable_multimodal_chat: bool = False,
-               **kwargs) -> list:
+    ) -> list:
         output_len = (output_len
                       if output_len is not None else self.DEFAULT_OUTPUT_LEN)
         sampled_requests = []
         for item in self.data:
             if len(sampled_requests) >= num_requests:
                 break
-            prompt = item["turns"][0][0]["content"]
+            parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
+            if parser_fn is None:
+                raise ValueError(
+                    f"Unsupported dataset path: {self.dataset_path}")
+            prompt = parser_fn(item)
             mm_content = process_image(item["images"][0])
             prompt_len = len(tokenizer(prompt).input_ids)
             if enable_multimodal_chat:
@@ -685,4 +720,98 @@ class VisionArenaDataset(HuggingFaceDataset):
                     expected_output_len=output_len,
                     multi_modal_data=mm_content,
                 ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Instruct Coder Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class InstructCoderDataset(HuggingFaceDataset):
+    """
+    InstructCoder Dataset.
+    https://huggingface.co/datasets/likaixin/InstructCoder
+
+    InstructCoder is the dataset designed for general code editing. It consists
+    of 114,239 instruction-input-output triplets, and covers multiple distinct
+    code editing scenario.
+    """
+
+    DEFAULT_OUTPUT_LEN = 200  # this is the average default output length
+    SUPPORTED_DATASET_PATHS = {
+        "likaixin/InstructCoder",
+    }
+
+    def sample(self,
+               tokenizer: PreTrainedTokenizerBase,
+               num_requests: int,
+               output_len: Optional[int] = None,
+               enable_multimodal_chat: bool = False,
+               **kwargs) -> list:
+        output_len = (output_len
+                      if output_len is not None else self.DEFAULT_OUTPUT_LEN)
+        sampled_requests = []
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt = f"{item['instruction']}:\n{item['input']}"
+            prompt_len = len(tokenizer(prompt).input_ids)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
+        return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# AIMO Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+class AIMODataset(HuggingFaceDataset):
+    """
+    Dataset class for processing a AIMO dataset with reasoning questions.
+    """
+    SUPPORTED_DATASET_PATHS = {
+        "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5",
+        "AI-MO/NuminaMath-CoT"
+    }
+
+    def sample(self,
+               tokenizer: PreTrainedTokenizerBase,
+               num_requests: int,
+               output_len: Optional[int] = None,
+               **kwargs) -> list:
+        sampled_requests = []
+        dynamic_output = output_len is None
+
+        for item in self.data:
+            if len(sampled_requests) >= num_requests:
+                break
+            prompt, completion = item['problem'], item["solution"]
+
+            prompt_ids = tokenizer(prompt).input_ids
+            completion_ids = tokenizer(completion).input_ids
+            prompt_len = len(prompt_ids)
+            completion_len = len(completion_ids)
+            output_len = completion_len if dynamic_output else output_len
+            assert isinstance(output_len, int) and output_len > 0
+            if dynamic_output and not is_valid_sequence(prompt_len,
+                                                        completion_len,
+                                                        max_prompt_len=2048,
+                                                        max_total_len=32000):
+                continue
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=None,
+                ))
+        self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
```
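Two behavioural changes in the dataset diff above are easy to miss among the hunks: `RandomDataset` now draws lengths from a symmetric window `[length * (1 - range_ratio), length * (1 + range_ratio)]` instead of `[length * range_ratio, length]`, and every sampler tops up its result with `maybe_oversample_requests` when the source dataset yields fewer samples than requested. A standalone sketch of both, using made-up numbers:

```python
# Standalone sketch of the new RandomDataset length window and the
# oversampling helper shown above; numbers are illustrative only.
import random

import numpy as np

input_len, range_ratio, num_requests = 1024, 0.25, 4
low = int(input_len * (1 - range_ratio))   # 768
high = int(input_len * (1 + range_ratio))  # 1280
input_lens = np.random.randint(low, high + 1, size=num_requests)
print(low, high, input_lens)


def maybe_oversample(requests: list, num_requests: int, seed: int = 0) -> None:
    # If too few samples were produced, repeat random ones until the target
    # count is reached (mirrors BenchmarkDataset.maybe_oversample_requests).
    if len(requests) < num_requests:
        random.seed(seed)
        requests.extend(random.choices(requests, k=num_requests - len(requests)))


reqs = ["a", "b"]
maybe_oversample(reqs, 5)
print(len(reqs))  # 5
```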
```diff
@@ -7,9 +7,6 @@ On the server side, run one of the following commands:
         --swap-space 16 \
         --disable-log-requests

-    (TGI backend)
-    ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
-
 On the client side, run:
     python benchmarks/benchmark_serving.py \
         --backend <backend> \
@@ -37,7 +34,8 @@ from datetime import datetime
 from typing import Any, Optional

 import numpy as np
-from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
+from backend_request_func import (ASYNC_REQUEST_FUNCS,
+                                  OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput,
                                   RequestFuncOutput)
 from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
@@ -52,9 +50,11 @@ try:
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser

-from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
-                               RandomDataset, SampleRequest, ShareGPTDataset,
-                               SonnetDataset, VisionArenaDataset)
+from benchmark_dataset import (AIMODataset, BurstGPTDataset,
+                               ConversationDataset, HuggingFaceDataset,
+                               InstructCoderDataset, RandomDataset,
+                               SampleRequest, ShareGPTDataset, SonnetDataset,
+                               VisionArenaDataset)
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json

 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -156,7 +156,7 @@ def calculate_metrics(
         if outputs[i].success:
             output_len = outputs[i].output_tokens

-            if output_len is None:
+            if not output_len:
                 # We use the tokenizer to count the number of output tokens
                 # for some serving backends instead of looking at
                 # len(outputs[i].itl) since multiple output tokens may be
@@ -261,6 +261,7 @@ async def benchmark(
     goodput_config_dict: dict[str, float],
     max_concurrency: Optional[int],
     lora_modules: Optional[Iterable[str]],
+    extra_body: Optional[dict],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -288,6 +289,7 @@ async def benchmark(
         logprobs=logprobs,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
+        extra_body=extra_body,
     )

     test_output = await request_func(request_func_input=test_input)
@@ -314,7 +316,8 @@ async def benchmark(
             output_len=test_output_len,
             logprobs=logprobs,
             multi_modal_content=test_mm_content,
-            ignore_eos=ignore_eos)
+            ignore_eos=ignore_eos,
+            extra_body=extra_body)
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
             print("Profiler started")
@@ -364,7 +367,8 @@ async def benchmark(
                                               output_len=output_len,
                                               logprobs=logprobs,
                                               multi_modal_content=mm_content,
-                                              ignore_eos=ignore_eos)
+                                              ignore_eos=ignore_eos,
+                                              extra_body=extra_body)
         tasks.append(
             asyncio.create_task(
                 limited_request_func(request_func_input=request_func_input,
@@ -586,19 +590,39 @@ def main(args: argparse.Namespace):
                                         return_prompt_formatted=True)

     elif args.dataset_name == "hf":
-        # Choose between VisionArenaDataset
-        # and HuggingFaceDataset based on provided parameters.
-        dataset_class = (VisionArenaDataset if args.dataset_path
-                         == VisionArenaDataset.VISION_ARENA_DATASET_PATH
-                         and args.hf_subset is None else HuggingFaceDataset)
+        # all following datasets are implemented from the
+        # HuggingFaceDataset base class
+        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = VisionArenaDataset
+            args.hf_split = "train"
+            args.hf_subset = None
+        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = InstructCoderDataset
+            args.hf_split = "train"
+        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = ConversationDataset
+        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
+            dataset_class = AIMODataset
+            args.hf_split = "train"
+        else:
+            supported_datasets = set([
+                dataset_name for cls in HuggingFaceDataset.__subclasses__()
+                for dataset_name in cls.SUPPORTED_DATASET_PATHS
+            ])
+            raise ValueError(
+                f"Unsupported dataset path: {args.dataset_path}. "
+                "Huggingface dataset only supports dataset_path"
+                f" from one of following: {supported_datasets}. "
+                "Please consider contributing if you would "
+                "like to add support for additional dataset formats.")
         input_requests = dataset_class(
             dataset_path=args.dataset_path,
             dataset_subset=args.hf_subset,
             dataset_split=args.hf_split,
+            random_seed=args.seed,
         ).sample(
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
-            random_seed=args.seed,
             output_len=args.hf_output_len,
         )

@@ -633,6 +657,26 @@ def main(args: argparse.Namespace):
         raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
     goodput_config_dict = check_goodput_args(args)

+    # Collect the sampling parameters.
+    sampling_params = {
+        k: v
+        for k, v in {
+            "top_p": args.top_p,
+            "top_k": args.top_k,
+            "min_p": args.min_p,
+            "temperature": args.temperature
+        }.items() if v is not None
+    }
+
+    # Sampling parameters are only supported by openai-compatible backend.
+    if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
+        raise ValueError(
+            "Sampling parameters are only supported by openai-compatible "
+            "backends.")
+
+    if "temperature" not in sampling_params:
+        sampling_params["temperature"] = 0.0  # Default to greedy decoding.
+
     # Avoid GC processing "static" data - reduce pause times.
     gc.collect()
     gc.freeze()
@@ -659,6 +703,7 @@ def main(args: argparse.Namespace):
             goodput_config_dict=goodput_config_dict,
             max_concurrency=args.max_concurrency,
             lora_modules=args.lora_modules,
+            extra_body=sampling_params,
         ))

     # Save config and results to json
@@ -876,7 +921,7 @@ if __name__ == "__main__":
         "--percentile-metrics",
         type=str,
         default="ttft,tpot,itl",
-        help="Comma-seperated list of selected metrics to report percentils. "
+        help="Comma-separated list of selected metrics to report percentils. "
         "This argument specifies the metrics to report percentiles. "
         "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
         "Default value is \"ttft,tpot,itl\".")
@@ -884,7 +929,7 @@ if __name__ == "__main__":
         "--metric-percentiles",
         type=str,
         default="99",
-        help="Comma-seperated list of percentiles for selected metrics. "
+        help="Comma-separated list of percentiles for selected metrics. "
        "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
        "Default value is \"99\". "
        "Use \"--percentile-metrics\" to select metrics.",
@@ -951,18 +996,23 @@ if __name__ == "__main__":
     random_group.add_argument(
         "--random-range-ratio",
         type=float,
-        default=1.0,
-        help="Range of sampled ratio of input/output length, "
-        "used only for random sampling.",
+        default=0.0,
+        help="Range ratio for sampling input/output length, "
+        "used only for random sampling. Must be in the range [0, 1) to define "
+        "a symmetric sampling range"
+        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
     )
     random_group.add_argument(
         "--random-prefix-len",
         type=int,
         default=0,
-        help="Number of fixed prefix tokens before random "
-        " context. The length range of context in a random "
-        " request is [random-prefix-len, "
-        " random-prefix-len + random-prefix-len * random-range-ratio).")
+        help=("Number of fixed prefix tokens before the random context "
+              "in a request. "
+              "The total input length is the sum of `random-prefix-len` and "
+              "a random "
+              "context length sampled from [input_len * (1 - range_ratio), "
+              "input_len * (1 + range_ratio)]."),
+    )

     hf_group = parser.add_argument_group("hf dataset options")
     hf_group.add_argument("--hf-subset",
@@ -981,6 +1031,33 @@ if __name__ == "__main__":
         "from the sampled HF dataset.",
     )

+    sampling_group = parser.add_argument_group("sampling parameters")
+    sampling_group.add_argument(
+        "--top-p",
+        type=float,
+        default=None,
+        help="Top-p sampling parameter. Only has effect on openai-compatible "
+        "backends.")
+    sampling_group.add_argument(
+        "--top-k",
+        type=int,
+        default=None,
+        help="Top-k sampling parameter. Only has effect on openai-compatible "
+        "backends.")
+    sampling_group.add_argument(
+        "--min-p",
+        type=float,
+        default=None,
+        help="Min-p sampling parameter. Only has effect on openai-compatible "
+        "backends.")
+    sampling_group.add_argument(
+        "--temperature",
+        type=float,
+        default=None,
+        help="Temperature sampling parameter. Only has effect on "
+        "openai-compatible backends. If not specified, default to greedy "
+        "decoding (i.e. temperature==0.0).")
+
     parser.add_argument(
         '--tokenizer-mode',
         type=str,
```
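The new `--top-p`, `--top-k`, `--min-p`, and `--temperature` flags above are folded into a single `extra_body` dict that is only allowed for OpenAI-compatible backends; unset flags are dropped and temperature defaults to greedy decoding. A small sketch of that filtering, with hypothetical argument values:

```python
# Sketch of how the CLI sampling flags collapse into extra_body
# (values are hypothetical; mirrors the sampling_params hunk above).
args = {"top_p": 0.9, "top_k": None, "min_p": None, "temperature": None}

sampling_params = {k: v for k, v in args.items() if v is not None}
if "temperature" not in sampling_params:
    sampling_params["temperature"] = 0.0  # default to greedy decoding

print(sampling_params)  # {'top_p': 0.9, 'temperature': 0.0}
```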
```diff
@@ -5,16 +5,13 @@ On the server side, run one of the following commands:
     (vLLM OpenAI API server)
     vllm serve <your_model> --disable-log-requests

-    (TGI backend)
-    ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
-
 On the client side, run:
     python benchmarks/benchmark_serving_structured_output.py \
         --backend <backend> \
         --model <your_model> \
         --dataset json \
         --structured-output-ratio 1.0 \
-        --structured-output-backend xgrammar \
+        --structured-output-backend auto \
         --request-rate 10 \
         --num-prompts 1000

@@ -133,10 +130,11 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                 "description":
                 "An unique optional field to avoid cached schemas"
             }
+    else:
+        json_schemas = [schema] * args.num_prompts

     def gen_prompt(index: int):
-        schema = json_schemas[index % len(json_schemas)]
-        return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}"  # noqa: E501
+        return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}"  # noqa: E501

     def get_schema(index: int):
         return json_schemas[index % len(json_schemas)]
@@ -732,8 +730,11 @@ def main(args: argparse.Namespace):
     api_url = f"http://{args.host}:{args.port}{args.endpoint}"
     base_url = f"http://{args.host}:{args.port}"

-    tokenizer = get_tokenizer(tokenizer_id,
-                              trust_remote_code=args.trust_remote_code)
+    tokenizer = get_tokenizer(
+        tokenizer_id,
+        trust_remote_code=args.trust_remote_code,
+        tokenizer_mode=args.tokenizer_mode,
+    )

     if args.dataset == 'grammar':
         args.structure_type = 'guided_grammar'
@@ -876,6 +877,13 @@ if __name__ == "__main__":
         help=
         "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
+    parser.add_argument(
+        "--tokenizer-mode",
+        type=str,
+        default="auto",
+        help=
+        "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
+    )
     parser.add_argument(
         "--num-prompts",
         type=int,
@@ -956,7 +964,7 @@ if __name__ == "__main__":
         "--percentile-metrics",
         type=str,
         default="ttft,tpot,itl",
-        help="Comma-seperated list of selected metrics to report percentils. "
+        help="Comma-separated list of selected metrics to report percentils. "
         "This argument specifies the metrics to report percentiles. "
         "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
         "Default value is \"ttft,tpot,itl\".")
@@ -964,7 +972,7 @@ if __name__ == "__main__":
         "--metric-percentiles",
         type=str,
         default="99",
-        help="Comma-seperated list of percentiles for selected metrics. "
+        help="Comma-separated list of percentiles for selected metrics. "
         "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
         "Default value is \"99\". "
         "Use \"--percentile-metrics\" to select metrics.",
@@ -991,8 +999,11 @@ if __name__ == "__main__":
                         help="Ratio of Structured Outputs requests")
     parser.add_argument("--structured-output-backend",
                         type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
+                        choices=[
+                            "outlines", "lm-format-enforcer", "xgrammar",
+                            "guidance", "auto"
+                        ],
+                        default="auto",
                         help="Backend to use for structured outputs")

     args = parser.parse_args()
```
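With the change above, both the unique-schema and shared-schema paths resolve the JSON schema per prompt index through `get_schema`, so `gen_prompt` no longer needs its own lookup. A toy sketch of that round-robin lookup, with illustrative schemas:

```python
# Toy sketch of per-index schema lookup, mirroring gen_prompt/get_schema above.
import json

json_schemas = [{"id": 0}, {"id": 1}, {"id": 2}]  # illustrative schemas


def get_schema(index: int) -> dict:
    return json_schemas[index % len(json_schemas)]


def gen_prompt(index: int) -> str:
    return ("Generate an example of a user profile given the following "
            f"schema: {json.dumps(get_schema(index))}")


print(gen_prompt(4))  # uses the schema with id 1
```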
|||||||
@ -11,7 +11,8 @@ from typing import Any, Optional, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvloop
|
import uvloop
|
||||||
from benchmark_dataset import (BurstGPTDataset, HuggingFaceDataset,
|
from benchmark_dataset import (AIMODataset, BurstGPTDataset,
|
||||||
|
ConversationDataset, InstructCoderDataset,
|
||||||
RandomDataset, SampleRequest, ShareGPTDataset,
|
RandomDataset, SampleRequest, ShareGPTDataset,
|
||||||
SonnetDataset, VisionArenaDataset)
|
SonnetDataset, VisionArenaDataset)
|
||||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||||
@ -212,14 +213,17 @@ def run_hf(
|
|||||||
max_prompt_len = 0
|
max_prompt_len = 0
|
||||||
max_output_len = 0
|
max_output_len = 0
|
||||||
for i in range(len(requests)):
|
for i in range(len(requests)):
|
||||||
prompt, prompt_len, output_len = requests[i]
|
prompt = requests[i].prompt
|
||||||
|
prompt_len = requests[i].prompt_len
|
||||||
|
output_len = requests[i].expected_output_len
|
||||||
# Add the prompt to the batch.
|
# Add the prompt to the batch.
|
||||||
batch.append(prompt)
|
batch.append(prompt)
|
||||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
max_prompt_len = max(max_prompt_len, prompt_len)
|
||||||
max_output_len = max(max_output_len, output_len)
|
max_output_len = max(max_output_len, output_len)
|
         if len(batch) < max_batch_size and i != len(requests) - 1:
             # Check if we can add more requests to the batch.
-            _, next_prompt_len, next_output_len = requests[i + 1]
+            next_prompt_len = requests[i + 1].prompt_len
+            next_output_len = requests[i + 1].expected_output_len
             if (max(max_prompt_len, next_prompt_len) +
                     max(max_output_len, next_output_len)) <= 2048:
                 # We can add more requests to the batch.

@@ -300,6 +304,7 @@ def get_requests(args, tokenizer):
         "input_len": args.input_len,
         "output_len": args.output_len,
     }

     if args.dataset_path is None or args.dataset_name == "random":
         sample_kwargs["range_ratio"] = args.random_range_ratio
         sample_kwargs["prefix_len"] = args.prefix_len

@@ -317,18 +322,23 @@ def get_requests(args, tokenizer):
     elif args.dataset_name == "burstgpt":
         dataset_cls = BurstGPTDataset
     elif args.dataset_name == "hf":
-        if args.backend != "vllm-chat":
-            raise ValueError(
-                "hf datasets only are supported by vllm-chat backend")
-        # Choose between VisionArenaDataset and HuggingFaceDataset based on
-        # provided parameters.
-        dataset_cls = (VisionArenaDataset if args.dataset_path
-                       == VisionArenaDataset.VISION_ARENA_DATASET_PATH
-                       and args.hf_subset is None else HuggingFaceDataset)
-        common_kwargs['dataset_subset'] = args.hf_subset
-        common_kwargs['dataset_split'] = args.hf_split
-        sample_kwargs["enable_multimodal_chat"] = True
+        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = VisionArenaDataset
+            common_kwargs['dataset_subset'] = None
+            common_kwargs['dataset_split'] = "train"
+            sample_kwargs["enable_multimodal_chat"] = True
+        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = InstructCoderDataset
+            common_kwargs['dataset_split'] = "train"
+        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = ConversationDataset
+            common_kwargs['dataset_subset'] = args.hf_subset
+            common_kwargs['dataset_split'] = args.hf_split
+            sample_kwargs["enable_multimodal_chat"] = True
+        elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
+            dataset_cls = AIMODataset
+            common_kwargs['dataset_subset'] = None
+            common_kwargs['dataset_split'] = "train"
     else:
         raise ValueError(f"Unknown dataset name: {args.dataset_name}")
     # Remove None values

@@ -462,9 +472,17 @@ def validate_args(args):
         warnings.warn("--hf-subset and --hf-split will be ignored \
             since --dataset-name is not 'hf'.",
                       stacklevel=2)
-    elif args.dataset_name == "hf" and args.backend != "vllm-chat":
-        raise ValueError(
-            "When --dataset-name is 'hf', backend must be 'vllm-chat'")
+    elif args.dataset_name == "hf":
+        if args.dataset_path in (
+                VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
+                | ConversationDataset.SUPPORTED_DATASET_PATHS):
+            assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend."  #noqa: E501
+        elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS
+                                   | AIMODataset.SUPPORTED_DATASET_PATHS):
+            assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend."  #noqa: E501
+        else:
+            raise ValueError(
+                f"{args.dataset_path} is not supported by hf dataset.")

     # --random-range-ratio: only used when dataset_name is 'random'
     if args.dataset_name != 'random' and args.random_range_ratio is not None:

@@ -576,18 +594,30 @@ if __name__ == "__main__":
         default=None,
         help="Path to the lora adapters to use. This can be an absolute path, "
         "a relative path, or a Hugging Face model identifier.")
-    parser.add_argument("--prefix-len",
-                        type=int,
-                        default=None,
-                        help="Number of prefix tokens per request."
-                        "This is for the RandomDataset and SonnetDataset")
+    parser.add_argument(
+        "--prefix-len",
+        type=int,
+        default=None,
+        help=f"Number of prefix tokens to be used in RandomDataset "
+        "and SonnetDataset. For RandomDataset, the total input "
+        "length is the sum of prefix-len (default: "
+        f"{RandomDataset.DEFAULT_PREFIX_LEN}) and a random context length "
+        "sampled from [input_len * (1 - range_ratio), "
+        "input_len * (1 + range_ratio)]. For SonnetDataset, "
+        f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) "
+        "controls how much of the input is fixed lines versus "
+        "random lines, but the total input length remains approximately "
+        "input_len tokens.")
     # random dataset
     parser.add_argument(
         "--random-range-ratio",
         type=float,
         default=None,
-        help="Range of sampled ratio of input/output length, "
-        "used only for RandomDataSet.",
+        help=f"Range ratio (default : {RandomDataset.DEFAULT_RANGE_RATIO}) "
+        "for sampling input/output length, "
+        "used only for RandomDataset. Must be in the range [0, 1) to "
+        "define a symmetric sampling range "
+        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )

     # hf dtaset
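The updated --prefix-len and --random-range-ratio help text above describes how RandomDataset derives request lengths. A minimal, illustrative sketch of that sampling rule (not the actual RandomDataset implementation; names here are made up):

import random


def sample_total_input_len(input_len: int, range_ratio: float,
                           prefix_len: int) -> int:
    # Context length is drawn from a symmetric range around input_len ...
    low = int(input_len * (1 - range_ratio))
    high = int(input_len * (1 + range_ratio))
    context_len = random.randint(low, high)
    # ... and the fixed prefix is added on top.
    return prefix_len + context_len


# e.g. sample_total_input_len(1024, 0.5, 0) falls somewhere in [512, 1536]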
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py  (new file, 340 lines)
@@ -0,0 +1,340 @@

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE

from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8,
                                                            fused_experts,
                                                            fused_topk)
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = [
    "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite",
    "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m"
]
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

PER_ACT_TOKEN_OPTS = [False]
PER_OUT_CH_OPTS = [False]


def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def bench_run(results: list[benchmark.Measurement], model: str,
              num_experts: int, topk: int, per_act_token: bool,
              per_out_ch: bool, mkn: tuple[int, int, int]):
    label = "Quant Matmul"

    sub_label = (
        "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, "
        "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch,
                          mkn))

    print(f"Testing: {sub_label}")

    (m, k, n) = mkn

    dtype = torch.half

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10

    _, a_scale = ops.scaled_fp8_quant(a)

    w1_q = torch.empty((num_experts, 2 * n, k),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w2_q = torch.empty((num_experts, k, n),
                       device="cuda",
                       dtype=torch.float8_e4m3fn)
    w1_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)
    w2_scale = torch.empty((num_experts, 1, 1),
                           device="cuda",
                           dtype=torch.float32)

    ab_strides1 = torch.full((num_experts, ),
                             k,
                             device="cuda",
                             dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ),
                            2 * n,
                            device="cuda",
                            dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ),
                             n,
                             device="cuda",
                             dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ),
                            k,
                            device="cuda",
                            dtype=torch.int64)

    for expert in range(num_experts):
        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
    w1_q_notransp = w1_q.clone()
    w2_q_notransp = w2_q.clone()
    w1_q = w1_q.transpose(1, 2)
    w2_q = w2_q.transpose(1, 2)

    score = torch.randn((m, num_experts), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)

    def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                       w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                       a_scale: torch.Tensor, num_repeats: int):
        for _ in range(num_repeats):
            fused_experts(a,
                          w1,
                          w2,
                          topk_weights,
                          topk_ids,
                          use_fp8_w8a8=True,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
                          a1_scale=a_scale)

    def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor,
                        w1: torch.Tensor, w2: torch.Tensor,
                        w1_scale: torch.Tensor, w2_scale: torch.Tensor,
                        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                        ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
                        ab_strides2: torch.Tensor, c_strides2: torch.Tensor,
                        num_repeats: int):
        for _ in range(num_repeats):
            cutlass_moe_fp8(a,
                            w1,
                            w2,
                            w1_scale,
                            w2_scale,
                            topk_weights,
                            topk_ids,
                            ab_strides1,
                            c_strides1,
                            ab_strides2,
                            c_strides2,
                            a1_scale=a_scale)

    def run_cutlass_from_graph(
            a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor,
            w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor,
            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
            ab_strides1: torch.Tensor, c_strides1: torch.Tensor,
            ab_strides2: torch.Tensor, c_strides2: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return cutlass_moe_fp8(a,
                                   w1_q,
                                   w2_q,
                                   w1_scale,
                                   w2_scale,
                                   topk_weights,
                                   topk_ids,
                                   ab_strides1,
                                   c_strides1,
                                   ab_strides2,
                                   c_strides2,
                                   a1_scale=a_scale)

    def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor,
                              w2: torch.Tensor, topk_weights: torch.Tensor,
                              topk_ids: torch.Tensor, w1_scale: torch.Tensor,
                              w2_scale: torch.Tensor, a_scale: torch.Tensor):
        with set_current_vllm_config(
                VllmConfig(parallel_config=ParallelConfig(
                    pipeline_parallel_size=1))):
            return fused_experts(a,
                                 w1,
                                 w2,
                                 topk_weights,
                                 topk_ids,
                                 use_fp8_w8a8=True,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 a1_scale=a_scale)

    def replay_graph(graph, num_repeats):
        for _ in range(num_repeats):
            graph.replay()
        torch.cuda.synchronize()

    cutlass_stream = torch.cuda.Stream()
    cutlass_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
        run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale,
                               topk_weights, topk_ids, ab_strides1, c_strides1,
                               ab_strides2, c_strides2)
    torch.cuda.synchronize()

    triton_stream = torch.cuda.Stream()
    triton_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(triton_graph, stream=triton_stream):
        run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights,
                              topk_ids, w1_scale, w2_scale, a_scale)
    torch.cuda.synchronize()

    min_run_time = 5
    num_warmup = 5
    num_runs = 25

    globals = {
        # Baseline params
        "w1": w1,
        "w2": w2,
        "score": score,
        "topk": topk,
        "w1_q_notransp": w1_q_notransp,
        "w2_q_notransp": w2_q_notransp,
        # Cutlass params
        "a_scale": a_scale,
        "w1_q": w1_q,
        "w2_q": w2_q,
        "w1_scale": w1_scale,
        "w2_scale": w2_scale,
        "ab_strides1": ab_strides1,
        "c_strides1": c_strides1,
        "ab_strides2": ab_strides2,
        "c_strides2": c_strides2,
        # cuda graph params
        "cutlass_graph": cutlass_graph,
        "triton_graph": triton_graph,
        # Gen params
        "a": a,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
        "num_runs": num_runs,
        # Kernels
        "run_triton_moe": run_triton_moe,
        "run_cutlass_moe": run_cutlass_moe,
        "replay_graph": replay_graph,
    }

    # Warmup
    run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids,
                   w1_scale, w2_scale, a_scale, num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(triton_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(triton_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="triton_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights,
                    topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2,
                    num_warmup)

    results.append(
        benchmark.Timer(
            stmt=
            "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe",
        ).blocked_autorange(min_run_time=min_run_time))

    # Warmup
    replay_graph(cutlass_graph, num_warmup)

    results.append(
        benchmark.Timer(
            stmt="replay_graph(cutlass_graph, num_runs)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="grouped_gemm_moe_cuda_graphs",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    results: list[benchmark.Measurement] = []

    for model in args.models:
        for tp in args.tp_sizes:
            for layer in WEIGHT_SHAPES_MOE[model]:
                num_experts = layer[0]
                topk = layer[1]
                size_k = layer[2]
                size_n = layer[3] // tp

                if len(args.limit_k) > 0 and size_k not in args.limit_k:
                    continue

                if len(args.limit_n) > 0 and size_n not in args.limit_n:
                    continue

                for per_act_token in PER_ACT_TOKEN_OPTS:
                    for per_out_ch in PER_OUT_CH_OPTS:
                        for size_m in DEFAULT_BATCH_SIZES:
                            mkn = (size_m, size_k, size_n)
                            bench_run(results, model, num_experts, topk,
                                      per_act_token, per_out_ch, mkn)

    compare = benchmark.Compare(results)
    compare.print()


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES_MOE.keys(),
    )
    parser.add_argument("--tp-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_TP_SIZES)
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
    parser.add_argument("--limit-per-act-token",
                        nargs="+",
                        type=int,
                        default=[])
    parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
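The new benchmark above times the Triton fused_experts path against cutlass_moe_fp8, both eagerly and under CUDA graph capture/replay. A minimal sketch of that capture-and-replay timing pattern (a stand-in matmul, requires a CUDA device; not part of the benchmark itself):

import torch
import torch.utils.benchmark as benchmark

a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")
a @ b  # warm up cuBLAS before capture
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    c = a @ b  # work captured once, then replayed without relaunch overhead


def replay(graph: torch.cuda.CUDAGraph, n: int) -> None:
    for _ in range(n):
        graph.replay()
    torch.cuda.synchronize()


timer = benchmark.Timer(stmt="replay(g, 25)",
                        globals={"replay": replay, "g": g})
print(timer.blocked_autorange(min_run_time=1))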
@@ -17,13 +17,8 @@ from torch.utils.benchmark import Measurement as TMeasurement
 from utils import ArgPool, Bench, CudaGraphBenchParams
 from weight_shapes import WEIGHT_SHAPES

-from vllm.lora.ops.triton_ops.bgmv_expand import bgmv_expand
-from vllm.lora.ops.triton_ops.bgmv_expand_slice import bgmv_expand_slice
-from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
-from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
-from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
+from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
 from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
-from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())

@@ -167,69 +162,25 @@ class OpType(Enum):
     """
     LoRA Ops to benchmark and its properties.
     """
-    SGMV_SHRINK = auto()
-    BGMV_SHRINK = auto()
-    SGMV_EXPAND = auto()
-    BGMV_EXPAND = auto()
-    BGMV_EXPAND_SLICE = auto()
-    V1_SHRINK = auto()
-    V1_EXPAND = auto()
+    LORA_SHRINK = auto()
+    LORA_EXPAND = auto()

     @staticmethod
     def from_str(s: str) -> "OpType":
-        if s.lower() == 'sgmv_shrink':
-            return OpType.SGMV_SHRINK
-        if s.lower() == 'sgmv_expand':
-            return OpType.SGMV_EXPAND
-        if s.lower() == 'bgmv_shrink':
-            return OpType.BGMV_SHRINK
-        if s.lower() == 'bgmv_expand':
-            return OpType.BGMV_EXPAND
-        if s.lower() == "bgmv_expand_slice":
-            return OpType.BGMV_EXPAND_SLICE
-        if s.lower() == "v1_shrink":
-            return OpType.V1_SHRINK
-        if s.lower() == "v1_expand":
-            return OpType.V1_EXPAND
+        if s.lower() == "lora_shrink":
+            return OpType.LORA_SHRINK
+        if s.lower() == "lora_expand":
+            return OpType.LORA_EXPAND
         raise ValueError(f"Unrecognized str {s} to convert to OpType")

     def is_shrink_fn(self) -> bool:
-        return self in [
-            OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
-        ]
+        return self in [OpType.LORA_SHRINK]

     def is_expand_fn(self) -> bool:
-        return self in [
-            OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
-        ]
-
-    def is_prefill_op(self) -> bool:
-        return self in [
-            OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
-            OpType.V1_EXPAND
-        ]
-
-    def is_decode_op(self) -> bool:
-        return self in [
-            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
-            OpType.V1_SHRINK, OpType.V1_EXPAND
-        ]
-
-    def is_expand_slice_fn(self) -> bool:
-        return self in [OpType.BGMV_EXPAND_SLICE]
+        return self in [OpType.LORA_EXPAND]

     def num_slices(self) -> list[int]:
-        if self in [
-                OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
-                OpType.V1_EXPAND
-        ]:
-            # SGMV kernels and v1 kernels supports slices
-            return [1, 2, 3]
-        if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
-            return [1]
-        if self in [OpType.BGMV_EXPAND_SLICE]:
-            return [2, 3]
-        raise ValueError(f"Unrecognized OpType {self}")
+        return [1, 2, 3]

     def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
             lora_rank: int) -> tuple[int, int, int]:

@@ -239,7 +190,7 @@ class OpType(Enum):
             k = hidden_size
             n = lora_rank
         else:
-            assert self.is_expand_fn() or self.is_expand_slice_fn()
+            assert self.is_expand_fn()
             m = num_tokens
             k = lora_rank
             n = hidden_size

@@ -254,7 +205,7 @@ class OpType(Enum):
         if self.is_shrink_fn():
             return op_dtype, op_dtype, torch.float32
         else:
-            assert self.is_expand_fn() or self.is_expand_slice_fn()
+            assert self.is_expand_fn()
             return torch.float32, op_dtype, op_dtype

     def matmul_shapes(

@@ -268,43 +219,19 @@ class OpType(Enum):
         m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)

         b_shape = (num_loras, n, k)  # col-major
-        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
-            # SGMV shrink and V1 shrink kernels support num_slices inherently
-            # in the kernel.
+        if self in [OpType.LORA_SHRINK]:
+            # LoRA shrink kernels support num_slices inherently in the kernel.
             return ((m, k), b_shape, (num_slices, m, n))
-        if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
-            # SGMV expand and V1 expand kernels support num_slices inherently
-            # in the kernel
+        if self in [OpType.LORA_EXPAND]:
+            # LoRA expand kernels support num_slices inherently in the kernel
             return ((num_slices, m, k), b_shape, (m, n * num_slices))
-        if self == OpType.BGMV_SHRINK:
-            return ((m, k), b_shape, (m, n))
-        if self == OpType.BGMV_EXPAND:
-            return ((m, k), b_shape, (m, n))
-        if self == OpType.BGMV_EXPAND_SLICE:
-            return ((num_slices, m, k), b_shape, (m, n * num_slices))
-
         raise ValueError(f"Unrecognized op_type {self}")

     def bench_fn(self) -> Callable:
-
-        def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
-            for x in kwargs_list:
-                bgmv_expand_slice(**x)
-
-        if self == OpType.SGMV_SHRINK:
-            return sgmv_shrink
-        if self == OpType.SGMV_EXPAND:
-            return sgmv_expand
-        if self == OpType.BGMV_SHRINK:
-            return bgmv_shrink
-        if self == OpType.BGMV_EXPAND:
-            return bgmv_expand
-        if self == OpType.BGMV_EXPAND_SLICE:
-            return emulate_bgmv_expand_slice
-        if self == OpType.V1_SHRINK:
-            return v1_shrink
-        if self == OpType.V1_EXPAND:
-            return v1_expand
-
+        if self == OpType.LORA_SHRINK:
+            return lora_shrink
+        if self == OpType.LORA_EXPAND:
+            return lora_expand
         raise ValueError(f"Unrecognized optype {self}")

@@ -318,34 +245,13 @@ class OpType(Enum):
         """
         w_dtype = lora_weights[0].dtype
         num_slices = len(lora_weights)
-        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
+        if self in [OpType.LORA_SHRINK]:
             for slice_idx in range(num_slices):
                 ref_group_gemm(ref_out=output[slice_idx, :],
                                input=input,
                                lora_weights=lora_weights[slice_idx],
                                **kwargs)
-        elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
-            hidden_size = lora_weights[0].shape[1]
-            for slice_idx in range(num_slices):
-                slice_offset = slice_idx * hidden_size
-                ref_group_gemm(
-                    ref_out=output[:, slice_offset:slice_offset + hidden_size],
-                    input=input[slice_idx].clone().to(dtype=w_dtype),
-                    lora_weights=lora_weights[slice_idx],
-                    **kwargs)
-        elif self == OpType.BGMV_SHRINK:
-            assert num_slices == 1
-            ref_group_gemm(ref_out=output,
-                           input=input,
-                           lora_weights=lora_weights[0],
-                           **kwargs)
-        elif self == OpType.BGMV_EXPAND:
-            assert num_slices == 1
-            ref_group_gemm(ref_out=output,
-                           input=input.clone().to(dtype=w_dtype),
-                           lora_weights=lora_weights[0],
-                           **kwargs)
-        elif self == OpType.BGMV_EXPAND_SLICE:
+        elif self in [OpType.LORA_EXPAND]:
             hidden_size = lora_weights[0].shape[1]
             for slice_idx in range(num_slices):
                 slice_offset = slice_idx * hidden_size

@@ -411,13 +317,11 @@ class BenchmarkTensors:
     input: torch.Tensor
     lora_weights_lst: list[torch.Tensor]
     output: torch.Tensor
-    # metadata tensors
+    # LoRA kernel metadata
+    lora_kernel_meta: LoRAKernelMeta
+    # Metadata tensors used in testing correctness
     seq_lens: torch.Tensor
-    seq_start_loc: torch.Tensor
     prompt_lora_mapping: torch.Tensor
-    token_lora_mapping: torch.Tensor
-    # v1 kernel metadata
-    v1_kernel_meta: Optional[V1KernelMeta] = None

     def io_types(self) -> str:
         return (f"{dtype_to_str(self.input.dtype)}x"

@@ -444,35 +348,29 @@ class BenchmarkTensors:
         assert ctx.num_active_loras <= ctx.num_loras
         total_tokens = ctx.batch_size * ctx.seq_length

+        # Make metadata tensors involved in correctness testing.
         # Prepare seq lens tensor
         seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1,
                                        (ctx.batch_size, ))
-        # Prepare seq_start_loc tensor
-        seq_start_loc_tensor = torch.cumsum(torch.tensor(
-            [0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
-                                            dim=0)
         assert total_tokens == seq_len_tensor.sum()
         # Prepare prompt lora indices tensor
         prompt_lora_indices_tensor = make_prompt_lora_mapping(
             ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu")
-        # Prepare token lora indices tensor
+
+        # Make LoRAKernelMeta
         token_lora_indices_tensor = make_token_lora_mapping(
             total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
             seq_len_tensor, "cpu")
-
-        v1_kernel_meta = None
-        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
-            v1_kernel_meta = V1KernelMeta.make(
-                max_loras=ctx.num_loras,
-                max_num_tokens=token_lora_indices_tensor.size(0),
-                device="cpu")
-            v1_kernel_meta.prepare_tensors(
-                token_lora_mapping=token_lora_indices_tensor)
+        lora_kernel_meta = LoRAKernelMeta.make(
+            max_loras=ctx.num_loras,
+            max_num_tokens=token_lora_indices_tensor.size(0),
+            device="cpu")
+        lora_kernel_meta.prepare_tensors(
+            token_lora_mapping=token_lora_indices_tensor)

         return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
-                                seq_len_tensor, seq_start_loc_tensor,
-                                prompt_lora_indices_tensor,
-                                token_lora_indices_tensor, v1_kernel_meta)
+                                lora_kernel_meta, seq_len_tensor,
+                                prompt_lora_indices_tensor)

     def sanity_check(self) -> None:
         """

@@ -482,9 +380,9 @@ class BenchmarkTensors:
         # check metadata tensors
         assert torch.sum(self.seq_lens) == num_tokens
         num_seqs = self.seq_lens.shape[0]
-        assert self.seq_start_loc.shape[0] == num_seqs
+        #assert self.seq_start_loc.shape[0] == num_seqs
         assert self.prompt_lora_mapping.shape[0] == num_seqs
-        assert self.token_lora_mapping.shape[0] == num_tokens
+        assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens

     def to_device(self, device: str):
         """

@@ -499,220 +397,27 @@ class BenchmarkTensors:
         self.input = to_device(self.input)
         self.output = to_device(self.output)
         self.seq_lens = to_device(self.seq_lens)
-        self.seq_start_loc = to_device(self.seq_start_loc)
         self.prompt_lora_mapping = to_device(self.prompt_lora_mapping)
-        self.token_lora_mapping = to_device(self.token_lora_mapping)
         for i in range(len(self.lora_weights_lst)):
             self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

-        # v1 meta
-        if self.v1_kernel_meta:
-            for field_name in V1KernelMeta.__dataclass_fields__:
-                field = getattr(self.v1_kernel_meta, field_name)
-                assert isinstance(field, torch.Tensor)
-                setattr(self.v1_kernel_meta, field_name, to_device(field))
+        # LoRA meta
+        for field_name in LoRAKernelMeta.__dataclass_fields__:
+            field = getattr(self.lora_kernel_meta, field_name)
+            assert isinstance(field, torch.Tensor)
+            setattr(self.lora_kernel_meta, field_name, to_device(field))

     def metadata(self) -> tuple[int, int, int]:
         """
         Return num_seqs, num_tokens and max_seq_len
         """
         num_seqs = self.seq_lens.shape[0]
-        num_tokens = self.token_lora_mapping.shape[0]
+        num_tokens = self.lora_kernel_meta.token_lora_mapping.shape[0]
         max_seq_len = torch.max(self.seq_lens).item()
         num_slices = len(self.lora_weights_lst)
         return num_seqs, num_tokens, max_seq_len, num_slices

-    def convert_to_sgmv_benchmark_tensors(self):
-        """
-        For sgmv punica kernels, when consecutive sequences have the
-        same LoRA ID, we just merge them together.
-        This happens in punica.py::compute_metadata
-        """
-
-        # Collapse seq_lens and seq_start_loc
-        _, seq_lens = torch.unique_consecutive(self.token_lora_mapping,
-                                               return_counts=True)
-        cum_result = torch.cumsum(seq_lens, dim=0)
-        seq_start_loc = torch.zeros_like(seq_lens)
-        seq_start_loc[1:].copy_(cum_result[:-1])
-
-        # Collapse prompt mapping
-        prompt_lora_mapping = torch.unique_consecutive(
-            self.prompt_lora_mapping)
-
-        assert torch.sum(seq_lens) == torch.sum(self.seq_lens), \
-            f"dont match - new {torch.sum(seq_lens)} vs {torch.sum(self.seq_lens)}"
-
-        self.prompt_lora_mapping = prompt_lora_mapping.to(
-            dtype=self.prompt_lora_mapping.dtype)
-        self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
-        self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
-
-    def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
-        self.convert_to_sgmv_benchmark_tensors()
-        self.sanity_check()
-        self.to_device(self.input.device)
-
-        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
-
-        # Sanity check matrix shapes.
-        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
-            0].shape, self.output.shape
-        # Expected input shape [num_tokens, hidden_size]
-        assert len(i_shape) == 2
-        assert i_shape[0] == num_tokens
-        hidden_size = i_shape[1]
-        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
-        assert len(lw_shape) == 3
-        assert lw_shape[2] == hidden_size
-        lora_rank = lw_shape[1]
-        # Expected output shape [num_slices, num_tokens, lora_rank]
-        assert len(o_shape) == 3
-        assert o_shape == (num_slices, num_tokens, lora_rank)
-
-        return {
-            'inputs': self.input,
-            'lora_a_weights': self.lora_weights_lst,
-            'output_tensor': self.output,
-            'b_seq_start_loc': self.seq_start_loc,
-            'seq_len_tensor': self.seq_lens,
-            'lora_indices_tensor': self.prompt_lora_mapping,
-            'batches': num_seqs,
-            'max_seq_length': max_seq_len,
-            'token_nums': num_tokens,
-            'scaling': 1.0,
-        }
-
-    def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
-
-        self.convert_to_sgmv_benchmark_tensors()
-        self.sanity_check()
-        self.to_device(self.input.device)
-
-        num_seqs, num_tokens, max_seq_len, num_slices = self.metadata()
-
-        # Sanity check matrix shapes.
-        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
-            0].shape, self.output.shape
-        # Expected input shape : [num_slices, num_tokens, lora_rank]
-        assert len(i_shape) == 3
-        assert i_shape[0] == num_slices
-        assert i_shape[1] == num_tokens
-        lora_rank = i_shape[2]
-        # Expected lora weight shape : [num_lora, hidden_size, lora_rank]
-        assert len(lw_shape) == 3
-        assert lw_shape[2] == lora_rank
-        hidden_size = lw_shape[1]
-        # Expected output shape : [num_tokens, hidden_size * num_slices]
-        assert len(o_shape) == 2
-        assert o_shape == (num_tokens, hidden_size * num_slices)
-
-        return {
-            'inputs': self.input,
-            'lora_b_weights': self.lora_weights_lst,
-            'output_tensor': self.output,
-            'b_seq_start_loc': self.seq_start_loc,
-            'seq_len_tensor': self.seq_lens,
-            'lora_indices_tensor': self.prompt_lora_mapping,
-            'batches': num_seqs,
-            'max_seq_length': max_seq_len,
-            'token_nums': num_tokens,
-            'offset_start': 0,
-            'add_inputs': add_inputs,
-        }
-
-    def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
-        assert len(self.lora_weights_lst) == 1
-        self.to_device(self.input.device)
-
-        _, num_tokens, _, _ = self.metadata()
-        # Sanity check shapes
-        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
-            0].shape, self.output.shape
-        # Expected input shape [num_tokens, hidden_size]
-        assert len(i_shape) == 2
-        assert i_shape[0] == num_tokens
-        hidden_size = i_shape[1]
-        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
-        assert len(lw_shape) == 3
-        assert lw_shape[2] == hidden_size
-        lora_rank = lw_shape[1]
-        # Expected output shape [num_tokens, lora_rank]
-        assert len(o_shape) == 2
-        assert o_shape == (num_tokens, lora_rank)
-
-        return {
-            'inputs': self.input,
-            'lora_a_weights': self.lora_weights_lst[0],
-            'output_tensor': self.output,
-            'lora_indices_tensor': self.token_lora_mapping,
-            'scaling': 1.0
-        }
-
-    def as_bgmv_expand_kwargs(self, add_inputs: bool):
-        assert len(self.lora_weights_lst) == 1
-        self.to_device(self.input.device)
-
-        _, num_tokens, _, _ = self.metadata()
-        # Sanity check shapes
-        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
-            0].shape, self.output.shape
-        # Expected input shape [num_tokens, lora_rank]
-        assert len(i_shape) == 2
-        assert i_shape[0] == num_tokens
-        lora_rank = i_shape[1]
-        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
-        assert len(lw_shape) == 3
-        assert lw_shape[2] == lora_rank
-        hidden_size = lw_shape[1]
-        # Expected output shape [num_tokens, hidden_size]
-        assert len(o_shape) == 2
-        assert o_shape == (num_tokens, hidden_size)
-
-        return {
-            'inputs': self.input,
-            'lora_b_weights': self.lora_weights_lst[0],
-            'output_tensor': self.output,
-            'lora_indices_tensor': self.token_lora_mapping,
-            'add_inputs': add_inputs
-        }
-
-    def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
-
-        _, num_tokens, _, num_slices = self.metadata()
-        # Sanity check shapes
-        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
-            0].shape, self.output.shape
-        # Expected input shape [num_slices, num_tokens, lora_rank]
-        assert len(i_shape) == 3
-        assert i_shape[0] == num_slices
-        assert i_shape[1] == num_tokens
-        lora_rank = i_shape[2]
-        # Expected lora weight shape [num_loras, hidden_size, lora_rank]
-        assert len(lw_shape) == 3
-        assert lw_shape[2] == lora_rank
-        hidden_size = lw_shape[1]
-        # Expected output shape [num_tokens, hidden_size * num_slices]
-        assert len(o_shape) == 2
-        assert o_shape == (num_tokens, hidden_size * num_slices)
-
-        self.to_device(self.input.device)
-
-        kwargs_list = []
-        for i in range(num_slices):
-            kwargs_list.append({
-                'inputs': self.input[i],
-                'lora_b_weights': self.lora_weights_lst[i],
-                'output_tensor': self.output,
-                'lora_indices_tensor': self.token_lora_mapping,
-                'slice_offset': i * hidden_size,
-                'slice_size': hidden_size,
-                'add_inputs': add_inputs,
-            })
-        return {'kwargs_list': kwargs_list}
-
-    def as_v1_shrink_kwargs(self) -> dict[str, Any]:
-        assert self.v1_kernel_meta is not None
+    def as_lora_shrink_kwargs(self) -> dict[str, Any]:
         self.sanity_check()
         self.to_device(self.input.device)

@@ -737,17 +442,16 @@ class BenchmarkTensors:
             'inputs': self.input,
             'lora_a_weights': self.lora_weights_lst,
             'output_tensor': self.output,
-            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
             'token_indices_sorted_by_lora_ids':
-            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
-            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
-            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
-            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.lora_kernel_meta.active_lora_ids,
             'scaling': 1.0,
         }

-    def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
-        assert self.v1_kernel_meta is not None
+    def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
         self.sanity_check()
         self.to_device(self.input.device)

@@ -773,12 +477,12 @@ class BenchmarkTensors:
             'inputs': self.input,
             'lora_b_weights': self.lora_weights_lst,
             'output_tensor': self.output,
-            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
+            'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping,
            'token_indices_sorted_by_lora_ids':
-            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
-            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
-            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
-            'lora_ids': self.v1_kernel_meta.active_lora_ids,
+            self.lora_kernel_meta.token_indices_sorted_by_lora_ids,
+            'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora,
+            'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc,
+            'lora_ids': self.lora_kernel_meta.active_lora_ids,
             'offset_start': 0,
             'add_inputs': add_inputs,
         }

@@ -791,20 +495,10 @@ class BenchmarkTensors:
         else:
             assert add_inputs is not None

-        if op_type == OpType.SGMV_SHRINK:
-            return self.as_sgmv_shrink_kwargs()
-        if op_type == OpType.SGMV_EXPAND:
-            return self.as_sgmv_expand_kwargs(add_inputs)
-        if op_type == OpType.BGMV_SHRINK:
-            return self.as_bgmv_shrink_kwargs()
-        if op_type == OpType.BGMV_EXPAND:
-            return self.as_bgmv_expand_kwargs(add_inputs)
-        if op_type == OpType.BGMV_EXPAND_SLICE:
-            return self.as_bgmv_expand_slice_kwargs(add_inputs)
-        if op_type == OpType.V1_SHRINK:
-            return self.as_v1_shrink_kwargs()
-        if op_type == OpType.V1_EXPAND:
-            return self.as_v1_expand_kwargs(add_inputs)
+        if op_type == OpType.LORA_SHRINK:
+            return self.as_lora_shrink_kwargs()
+        if op_type == OpType.LORA_EXPAND:
+            return self.as_lora_expand_kwargs(add_inputs)
         raise ValueError(f"Unrecognized optype {self}")

     def test_correctness(self, op_type: OpType,

@@ -993,10 +687,6 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
     for bench_ctx in bench_ctxs:
         for seq_len in args.seq_lengths:
             bench_ops: list[OpType] = args.op_types
-            if seq_len > 1:
-                # bench only prefill ops
-                bench_ops = [op for op in args.op_types if op.is_prefill_op()]

             seq_len_timers = []
             for bench_op in bench_ops:
                 for num_slices in bench_op.num_slices():

@@ -1206,13 +896,13 @@ Benchmark LoRA kernels:
 {use_cuda_graph_recommendation()}

 list_bench example:
-    python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
+    python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

 model_bench example:
-    python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
+    python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32

 range_bench example:
-    python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
+    python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
 """,  # noqa: E501
     formatter_class=argparse.RawTextHelpFormatter)
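Since the diff above collapses the sgmv_*/bgmv_*/v1_* op types into lora_shrink and lora_expand, older --op-types invocations need updating. A hypothetical sanity check (assumes benchmarks/kernels is on sys.path and vLLM is installed so benchmark_lora imports):

from benchmark_lora import OpType  # assumed import path

for name in ("lora_shrink", "lora_expand"):
    assert OpType.from_str(name) in (OpType.LORA_SHRINK, OpType.LORA_EXPAND)
# Old names such as "sgmv_shrink" or "bgmv_expand" now raise ValueError.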
@@ -30,19 +30,18 @@ class BenchmarkConfig(TypedDict):
     num_stages: int


-def benchmark_config(
-    config: BenchmarkConfig,
-    num_tokens: int,
-    num_experts: int,
-    shard_intermediate_size: int,
-    hidden_size: int,
-    topk: int,
-    dtype: torch.dtype,
-    use_fp8_w8a8: bool,
-    use_int8_w8a16: bool,
-    num_iters: int = 100,
-    block_quant_shape: List[int] = None,
-) -> float:
+def benchmark_config(config: BenchmarkConfig,
+                     num_tokens: int,
+                     num_experts: int,
+                     shard_intermediate_size: int,
+                     hidden_size: int,
+                     topk: int,
+                     dtype: torch.dtype,
+                     use_fp8_w8a8: bool,
+                     use_int8_w8a16: bool,
+                     num_iters: int = 100,
+                     block_quant_shape: List[int] = None,
+                     use_deep_gemm: bool = False) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
     if use_int8_w8a16:

@@ -115,22 +114,41 @@ def benchmark_config(
     def run():
         from vllm.model_executor.layers.fused_moe import override_config
         with override_config(config):
-            fused_moe(
-                x,
-                w1,
-                w2,
-                input_gating,
-                topk,
-                renormalize=True,
-                inplace=True,
-                use_fp8_w8a8=use_fp8_w8a8,
-                use_int8_w8a16=use_int8_w8a16,
-                w1_scale=w1_scale,
-                w2_scale=w2_scale,
-                a1_scale=a1_scale,
-                a2_scale=a2_scale,
-                block_shape=block_quant_shape,
-            )
+            if use_deep_gemm:
+                topk_weights, topk_ids = fused_topk(x, input_gating, topk,
+                                                    False)
+                return fused_experts(
+                    x,
+                    w1,
+                    w2,
+                    topk_weights,
+                    topk_ids,
+                    inplace=True,
+                    use_fp8_w8a8=use_fp8_w8a8,
+                    w1_scale=w1_scale,
+                    w2_scale=w2_scale,
+                    a1_scale=a1_scale,
+                    a2_scale=a2_scale,
+                    block_shape=block_quant_shape,
+                    allow_deep_gemm=True,
+                )
+            else:
+                fused_moe(
+                    x,
+                    w1,
+                    w2,
+                    input_gating,
+                    topk,
+                    renormalize=True,
+                    inplace=True,
+                    use_fp8_w8a8=use_fp8_w8a8,
+                    use_int8_w8a16=use_int8_w8a16,
+                    w1_scale=w1_scale,
+                    w2_scale=w2_scale,
+                    a1_scale=a1_scale,
+                    a2_scale=a2_scale,
+                    block_shape=block_quant_shape,
+                )

     # JIT compilation & warmup
     run()

@@ -366,6 +384,7 @@ class BenchmarkWorker:
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         block_quant_shape: List[int] = None,
+        use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
         current_platform.seed_everything(self.seed)
         dtype_str = get_config_dtype_str(dtype,

@@ -396,7 +415,8 @@ class BenchmarkWorker:
             use_fp8_w8a8,
             use_int8_w8a16,
             num_iters=100,
-            block_quant_shape=block_quant_shape)
+            block_quant_shape=block_quant_shape,
+            use_deep_gemm=use_deep_gemm)
         return config, kernel_time

     def tune(

@@ -411,6 +431,7 @@ class BenchmarkWorker:
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
         block_quant_shape: list[int],
+        use_deep_gemm: bool,
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")

@@ -436,7 +457,8 @@ class BenchmarkWorker:
                     use_fp8_w8a8,
                     use_int8_w8a16,
                     num_iters=20,
-                    block_quant_shape=block_quant_shape)
+                    block_quant_shape=block_quant_shape,
+                    use_deep_gemm=use_deep_gemm)
             except triton.runtime.autotuner.OutOfResources:
                 # Some configurations may be invalid and fail to compile.
                 continue

@@ -531,6 +553,9 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
+        if not hasattr(config, "hidden_size"):
+            # Support for llama4
+            config = config.text_config
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok

@@ -550,6 +575,8 @@ def main(args: argparse.Namespace):
     else:
         batch_sizes = [args.batch_size]

+    use_deep_gemm = bool(args.use_deep_gemm)
+
     ray.init()
     num_gpus = int(ray.available_resources()["GPU"])
     workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

@@ -572,10 +599,10 @@ def main(args: argparse.Namespace):

     start = time.time()
     configs = _distribute(
-        "tune",
-        [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
-          use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+        "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
+                  topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space,
+                  block_quant_shape, use_deep_gemm)
         for batch_size in batch_sizes])
     best_configs = {
         M: sort_config(config)
         for M, config in zip(batch_sizes, configs)

@@ -589,7 +616,7 @@ def main(args: argparse.Namespace):
     outputs = _distribute(
         "benchmark",
         [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
-          use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+          use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm)
         for batch_size in batch_sizes])

     for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):

@@ -611,6 +638,7 @@ if __name__ == "__main__":
                         type=str,
                         choices=["auto", "fp8_w8a8", "int8_w8a16"],
                         default="auto")
+    parser.add_argument("--use-deep-gemm", action="store_true")
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--batch-size", type=int, required=False)
     parser.add_argument("--tune", action="store_true")
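The hunks above thread a new --use-deep-gemm switch from the CLI into benchmark_config. A hedged, self-contained sketch of that flag propagation (the real kernels are stubbed out; this is not the actual tuner):

import argparse


def benchmark_config(num_tokens: int, use_deep_gemm: bool = False) -> float:
    # In the real script this selects fused_experts(..., allow_deep_gemm=True)
    # instead of fused_moe() when the flag is set; here it is just a stub.
    return 1.0 if use_deep_gemm else 0.0


parser = argparse.ArgumentParser()
parser.add_argument("--use-deep-gemm", action="store_true")
args = parser.parse_args(["--use-deep-gemm"])  # simulate passing the flag

use_deep_gemm = bool(args.use_deep_gemm)
print(benchmark_config(num_tokens=64, use_deep_gemm=use_deep_gemm))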
@@ -7,10 +7,13 @@ from typing import Optional
 import torch

 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)

+logger = init_logger(__name__)
+
 NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
 PARTITION_SIZE_ROCM = 256

@@ -193,6 +196,9 @@ def main(


 if __name__ == '__main__':
+    logger.warning("This script benchmarks the paged attention kernel. "
+                   "By default this is no longer used in vLLM inference.")
+
     parser = FlexibleArgumentParser(
         description="Benchmark the paged attention kernel.")
     parser.add_argument("--version",
@@ -75,3 +75,19 @@ WEIGHT_SHAPES = {
         [7168, 8192],
     ],
 }
+
+WEIGHT_SHAPES_MOE = {
+    "nm-testing/Mixtral-8x7B-Instruct-v0.1": [
+        [8, 2, 4096, 28672],
+        [8, 2, 14336, 4096],
+    ],
+    "nm-testing/deepseekv2-lite": [
+        [64, 6, 2048, 1408],
+    ],
+    "ibm-granite/granite-3.0-1b-a400m": [
+        [32, 8, 1024, 1024],
+    ],
+    "ibm-granite/granite-3.0-3b-a800m": [
+        [40, 8, 1024, 1536],
+    ],
+}
|
|||||||
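For orientation, a minimal sketch of how these MoE shape tuples could be iterated by a benchmark driver. The field order (num_experts, topk, intermediate_size, hidden_size) is an assumption inferred from the values, not stated in the diff itself.

# Assumed field order: [num_experts, topk, intermediate_size, hidden_size].
WEIGHT_SHAPES_MOE = {
    "nm-testing/deepseekv2-lite": [[64, 6, 2048, 1408]],
}

for model, shapes in WEIGHT_SHAPES_MOE.items():
    for num_experts, topk, intermediate_size, hidden_size in shapes:
        print(f"{model}: E={num_experts} topk={topk} "
              f"I={intermediate_size} H={hidden_size}")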
benchmarks/kernels/benchmark_w8a8_block_fp8.py (new file, 420 lines)
@@ -0,0 +1,420 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Adapted from sglang quantization/tuning_block_wise_kernel.py
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm  # tqdm is called directly below, so import the callable
|
||||||
|
import triton
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
_w8a8_block_fp8_matmul)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
assert current_platform.is_cuda(
|
||||||
|
), "Only support tune w8a8 block fp8 kernel on CUDA device."
|
||||||
|
|
||||||
|
DTYPE_MAP = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"half": torch.half,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def w8a8_block_matmul(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
config: dict[str, Any],
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""This function performs matrix multiplication with
|
||||||
|
block-wise quantization.
|
||||||
|
|
||||||
|
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||||
|
The output is returned in the specified `output_dtype`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
A: The input tensor, e.g., activation.
|
||||||
|
B: The input tensor, e.g., weight.
|
||||||
|
As: The per-token-group quantization scale for `A`.
|
||||||
|
Bs: The per-block quantization scale for `B`.
|
||||||
|
block_size: The block size for per-block quantization.
|
||||||
|
It should be 2-dim, e.g., [128, 128].
|
||||||
|
output_dtype: The dtype of the returned tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The result of matmul.
|
||||||
|
"""
|
||||||
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
|
||||||
|
assert A.shape[-1] == B.shape[-1]
|
||||||
|
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||||
|
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||||
|
M = A.numel() // A.shape[-1]
|
||||||
|
|
||||||
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
|
N, K = B.shape
|
||||||
|
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||||
|
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||||
|
|
||||||
|
C_shape = A.shape[:-1] + (N, )
|
||||||
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) *
|
||||||
|
triton.cdiv(N, META["BLOCK_SIZE_N"]), )
|
||||||
|
|
||||||
|
if A.dtype == torch.float8_e4m3fn:
|
||||||
|
kernel = _w8a8_block_fp8_matmul
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||||
|
|
||||||
|
kernel[grid](
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
A.stride(-2),
|
||||||
|
A.stride(-1),
|
||||||
|
B.stride(1),
|
||||||
|
B.stride(0),
|
||||||
|
C.stride(-2),
|
||||||
|
C.stride(-1),
|
||||||
|
As.stride(-2),
|
||||||
|
As.stride(-1),
|
||||||
|
Bs.stride(1),
|
||||||
|
Bs.stride(0),
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return C
|
||||||
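A hedged usage sketch for the helper above (editorial addition, not part of the committed file): build fp8 operands with the same scale layout that tune() uses further down, pick one tile configuration, and call w8a8_block_matmul directly. Shapes and config values are illustrative.

def _example_w8a8_call():
    # Editorial sketch; requires a CUDA device, like the rest of this script.
    M, N, K, block_n, block_k = 64, 512, 512, 128, 128
    fp8 = torch.float8_e4m3fn
    A = torch.randn(M, K, device="cuda").to(fp8)
    B = torch.randn(N, K, device="cuda").to(fp8)
    # Per-token-group scales for A and per-block scales for B.
    As = torch.rand(M, K // block_k, dtype=torch.float32, device="cuda")
    Bs = torch.rand(N // block_n, K // block_k, dtype=torch.float32,
                    device="cuda")
    config = {
        "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3,
    }
    return w8a8_block_matmul(A, B, As, Bs, [block_n, block_k], config,
                             torch.float16)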
|
|
||||||
|
|
||||||
|
def get_configs_compute_bound():
|
||||||
|
configs = []
|
||||||
|
for num_stages in [2, 3, 4, 5]:
|
||||||
|
for block_m in [16, 32, 64, 128, 256]:
|
||||||
|
for block_k in [64, 128]:
|
||||||
|
for block_n in [32, 64, 128, 256]:
|
||||||
|
for num_warps in [4, 8]:
|
||||||
|
for group_size in [1, 16, 32, 64]:
|
||||||
|
configs.append({
|
||||||
|
"BLOCK_SIZE_M": block_m,
|
||||||
|
"BLOCK_SIZE_N": block_n,
|
||||||
|
"BLOCK_SIZE_K": block_k,
|
||||||
|
"GROUP_SIZE_M": group_size,
|
||||||
|
"num_warps": num_warps,
|
||||||
|
"num_stages": num_stages,
|
||||||
|
})
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_shapes(tp_size):
|
||||||
|
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3.
|
||||||
|
# Modify them, if you tune for another different model.
|
||||||
|
# cannot TP
|
||||||
|
total = [
|
||||||
|
(512 + 64, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(7168, 16384),
|
||||||
|
(7168, 18432),
|
||||||
|
]
|
||||||
|
# N can TP
|
||||||
|
n_tp = [
|
||||||
|
(18432 * 2, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(24576, 1536),
|
||||||
|
(12288, 7168),
|
||||||
|
(4096, 7168),
|
||||||
|
]
|
||||||
|
# K can TP
|
||||||
|
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||||
|
|
||||||
|
weight_shapes = []
|
||||||
|
for t in total:
|
||||||
|
weight_shapes.append(t)
|
||||||
|
for n_t in n_tp:
|
||||||
|
new_t = (n_t[0] // tp_size, n_t[1])
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
for k_t in k_tp:
|
||||||
|
new_t = (k_t[0], k_t[1] // tp_size)
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
return weight_shapes
|
||||||
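A small worked example of the tensor-parallel splitting performed above, assuming tp_size = 8: shapes in n_tp divide their first dimension, shapes in k_tp divide their second.

tp_size = 8
n_tp_example = (18432 * 2, 7168)   # N can be split across TP ranks
k_tp_example = (7168, 16384)       # K can be split across TP ranks
print((n_tp_example[0] // tp_size, n_tp_example[1]))  # (4608, 7168)
print((k_tp_example[0], k_tp_example[1] // tp_size))  # (7168, 2048)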
|
|
||||||
|
|
||||||
|
def benchmark_config(A,
|
||||||
|
B,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
block_size,
|
||||||
|
config,
|
||||||
|
out_dtype=torch.float16,
|
||||||
|
num_iters=10):
|
||||||
|
|
||||||
|
def run():
|
||||||
|
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# JIT compilation & warmup
|
||||||
|
for _ in range(5):
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
for i in range(num_iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_event.record()
|
||||||
|
run()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||||
|
return avg
|
||||||
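# Note on benchmark_config() above: torch.cuda.Event.elapsed_time() reports
# milliseconds, and the five warmup iterations also absorb Triton's JIT
# compilation, so only steady-state kernel time enters the averaged latency.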
|
|
||||||
|
|
||||||
|
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
|
||||||
|
if input_type == "fp8":
|
||||||
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
|
A_fp32 = (
|
||||||
|
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||||
|
fp8_max)
|
||||||
|
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
B_fp32 = (
|
||||||
|
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 *
|
||||||
|
fp8_max)
|
||||||
|
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Currently, only support tune w8a8 block fp8 kernel.")
|
||||||
|
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
|
||||||
|
As = torch.rand(M, k_tiles, dtype=torch.float32,
|
||||||
|
device="cuda") * factor_for_scale
|
||||||
|
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") *
|
||||||
|
factor_for_scale)
|
||||||
|
|
||||||
|
best_config = None
|
||||||
|
best_time = float("inf")
|
||||||
|
for config in tqdm(search_space):
|
||||||
|
try:
|
||||||
|
kernel_time = benchmark_config(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
block_size,
|
||||||
|
config,
|
||||||
|
out_dtype,
|
||||||
|
num_iters=10,
|
||||||
|
)
|
||||||
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
|
# Some configurations may be invalid and fail to compile.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if kernel_time < best_time:
|
||||||
|
best_time = kernel_time
|
||||||
|
best_config = config
|
||||||
|
now = datetime.now()
|
||||||
|
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
|
||||||
|
assert best_config is not None
|
||||||
|
return best_config
|
||||||
|
|
||||||
|
|
||||||
|
def save_configs(
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
configs,
|
||||||
|
save_path,
|
||||||
|
input_type="fp8",
|
||||||
|
) -> None:
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||||
|
json_file_name = (
|
||||||
|
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
|
||||||
|
f"block_shape=[{block_n},{block_k}].json")
|
||||||
|
|
||||||
|
config_file_path = os.path.join(save_path, json_file_name)
|
||||||
|
print(f"Writing best config to {config_file_path}...")
|
||||||
|
|
||||||
|
with open(config_file_path, "w") as f:
|
||||||
|
json.dump(configs, f, indent=4)
|
||||||
|
f.write("\n")
|
||||||
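For illustration, with the default block_n = block_k = 128 and fp8 input the file written above is named as follows; the device segment is only an example and comes from current_platform.get_device_name().

# Example filename only; N, K and the device name vary per shape and GPU.
example_name = ("N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,"
                "dtype=fp8_w8a8,block_shape=[128,128].json")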
|
|
||||||
|
|
||||||
|
def tune_on_gpu(args_dict):
|
||||||
|
"""Run tuning on a specific GPU."""
|
||||||
|
gpu_id = args_dict["gpu_id"]
|
||||||
|
batch_sizes = args_dict["batch_sizes"]
|
||||||
|
weight_shapes = args_dict["weight_shapes"]
|
||||||
|
args = args_dict["args"]
|
||||||
|
|
||||||
|
torch.cuda.set_device(gpu_id)
|
||||||
|
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
|
||||||
|
|
||||||
|
block_n = args.block_n
|
||||||
|
block_k = args.block_k
|
||||||
|
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||||
|
save_path = args.save_path
|
||||||
|
input_type = args.input_type
|
||||||
|
|
||||||
|
search_space = get_configs_compute_bound()
|
||||||
|
search_space = [
|
||||||
|
config for config in search_space
|
||||||
|
if block_k % config["BLOCK_SIZE_K"] == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
|
||||||
|
N, K = shape[0], shape[1]
|
||||||
|
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||||
|
benchmark_results = [
|
||||||
|
tune(
|
||||||
|
batch_size,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
[block_n, block_k],
|
||||||
|
out_dtype,
|
||||||
|
search_space,
|
||||||
|
input_type,
|
||||||
|
) for batch_size in tqdm(batch_sizes,
|
||||||
|
desc=f"GPU {gpu_id} - Batch sizes")
|
||||||
|
]
|
||||||
|
best_configs = {
|
||||||
|
M: config
|
||||||
|
for M, config in zip(batch_sizes, benchmark_results)
|
||||||
|
}
|
||||||
|
save_configs(N, K, block_n, block_k, best_configs, save_path,
|
||||||
|
input_type)
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_batch_sizes(batch_sizes, num_gpus):
|
||||||
|
"""Distribute batch sizes across available GPUs."""
|
||||||
|
batches_per_gpu = []
|
||||||
|
for i in range(num_gpus):
|
||||||
|
start_idx = i * len(batch_sizes) // num_gpus
|
||||||
|
end_idx = (i + 1) * len(batch_sizes) // num_gpus
|
||||||
|
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
|
||||||
|
return batches_per_gpu
|
||||||
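A quick example of the split computed above.

print(distribute_batch_sizes([1, 2, 4, 8, 16], num_gpus=2))
# -> [[1, 2], [4, 8, 16]]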
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args)
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
if num_gpus == 0:
|
||||||
|
raise RuntimeError("No GPU available for tuning")
|
||||||
|
print(f"Found {num_gpus} GPUs for parallel tuning")
|
||||||
|
|
||||||
|
torch.cuda.init()
|
||||||
|
|
||||||
|
if args.batch_size is None:
|
||||||
|
batch_sizes = [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
32,
|
||||||
|
48,
|
||||||
|
64,
|
||||||
|
96,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1536,
|
||||||
|
2048,
|
||||||
|
3072,
|
||||||
|
4096,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch_sizes = [args.batch_size]
|
||||||
|
num_gpus = 1 # If only one batch size, use only one GPU
|
||||||
|
|
||||||
|
weight_shapes = get_weight_shapes(args.tp_size)
|
||||||
|
|
||||||
|
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
|
||||||
|
|
||||||
|
process_args = []
|
||||||
|
for gpu_id in range(num_gpus):
|
||||||
|
process_args.append({
|
||||||
|
"gpu_id": gpu_id,
|
||||||
|
"batch_sizes": batches_per_gpu[gpu_id],
|
||||||
|
"weight_shapes":
|
||||||
|
weight_shapes, # Each GPU processes all weight shapes
|
||||||
|
"args": args,
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx = mp.get_context("spawn")
|
||||||
|
with ctx.Pool(num_gpus) as pool:
|
||||||
|
pool.map(tune_on_gpu, process_args)
|
||||||
|
|
||||||
|
print("Multi-GPU tuning completed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""
|
||||||
|
Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
|
||||||
|
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
|
||||||
|
Then copy to model_executor/layers/quantization/utils/configs
|
||||||
|
""",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter)
|
||||||
|
|
||||||
|
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||||
|
parser.add_argument("--input-type",
|
||||||
|
type=str,
|
||||||
|
choices=["fp8"],
|
||||||
|
default="fp8")
|
||||||
|
parser.add_argument(
|
||||||
|
"--out-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["float32", "float16", "bfloat16", "half"],
|
||||||
|
default="float16",
|
||||||
|
)
|
||||||
|
parser.add_argument("--block-n", type=int, default=128)
|
||||||
|
parser.add_argument("--block-k", type=int, default=128)
|
||||||
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
parser.add_argument("--save-path", type=str, default="./")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
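A hedged sketch of the consumer side of this tuner: after copying the JSON it writes into the configs directory, a caller could look up the entry for the nearest batch size and feed it back into w8a8_block_matmul. The file name and the nearest-M lookup below are illustrative, not vLLM's actual config-loading logic.

import json

def load_best_config(path: str, M: int) -> dict:
    # Keys in the tuned JSON are batch sizes (serialized as strings); pick the
    # closest one to the requested M.
    with open(path) as f:
        configs = json.load(f)
    best_m = min(configs, key=lambda k: abs(int(k) - M))
    return configs[best_m]

# cfg = load_best_config(
#     "N=4608,K=7168,device_name=...,dtype=fp8_w8a8,block_shape=[128,128].json",
#     M=64)
# C = w8a8_block_matmul(A, B, As, Bs, [128, 128], cfg, torch.float16)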
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-PORT=8000
-MODEL=$1
-TOKENS=$2
-
-docker run -e "HF_TOKEN=$HF_TOKEN" --gpus all --shm-size 1g -p $PORT:80 \
-           -v "$PWD/data:/data" \
-           ghcr.io/huggingface/text-generation-inference:2.2.0 \
-           --model-id "$MODEL" \
-           --sharded false \
-           --max-input-length 1024 \
-           --max-total-tokens 2048 \
-           --max-best-of 5 \
-           --max-concurrent-requests 5000 \
-           --max-batch-total-tokens "$TOKENS"
@@ -54,6 +54,7 @@ for qps in "${QPS_VALUES[@]}"; do
   python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
       --request-rate $qps \
       --result-filename "$FILENAME" \
+      --tokenizer-mode ${TOKENIZER_MODE:-"auto"} \
       --port ${PORT:-8000}

   echo "Completed benchmark with QPS: $qps"
@@ -33,8 +33,6 @@ endif()

 if(MACOSX_FOUND)
   list(APPEND CXX_COMPILE_FLAGS
-    "-Xpreprocessor"
-    "-fopenmp"
     "-DVLLM_CPU_EXTENSION")
 else()
   list(APPEND CXX_COMPILE_FLAGS
@@ -190,12 +188,14 @@ set(VLLM_EXT_SRC
     "csrc/cpu/cache.cpp"
     "csrc/cpu/utils.cpp"
     "csrc/cpu/layernorm.cpp"
+    "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
     "csrc/cpu/torch_bindings.cpp")

 if (AVX512_FOUND AND NOT AVX512_DISABLED)
   set(VLLM_EXT_SRC
     "csrc/cpu/quant.cpp"
+    "csrc/cpu/shm.cpp"
     ${VLLM_EXT_SRC})
 endif()

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
           vllm-flash-attn
           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-          GIT_TAG 9bfa9869829d8c593527eb34c5271d0090f7ccc9
+          GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
           GIT_PROGRESS TRUE
           # Don't share the vllm-flash-attn build between build types
           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
csrc/attention/merge_attn_states.cu (new file, 178 lines)
@@ -0,0 +1,178 @@
|
|||||||
|
#include <optional>
|
||||||
|
#include <torch/all.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "attention_dtypes.h"
|
||||||
|
#include "attention_utils.cuh"
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
|
|
||||||
|
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||||
|
// can be used to combine partial attention results (in the split-KV case)
|
||||||
|
template <typename scalar_t, const uint NUM_THREADS>
|
||||||
|
__global__ void merge_attn_states_kernel(
|
||||||
|
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||||
|
const float* prefix_lse, const scalar_t* suffix_output,
|
||||||
|
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||||
|
const uint head_size) {
|
||||||
|
using pack_128b_t = uint4;
|
||||||
|
const uint pack_size = 16 / sizeof(scalar_t);
|
||||||
|
const uint threads_per_head = head_size / pack_size;
|
||||||
|
|
||||||
|
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
|
||||||
|
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
|
||||||
|
|
||||||
|
if (global_idx >= token_head_threads) return;
|
||||||
|
|
||||||
|
// global_idx -> token_idx + head_idx + pack_idx
|
||||||
|
const uint token_head_idx = global_idx / threads_per_head;
|
||||||
|
const uint pack_idx = global_idx % threads_per_head;
|
||||||
|
|
||||||
|
const uint token_idx = token_head_idx / num_heads;
|
||||||
|
const uint head_idx = token_head_idx % num_heads;
|
||||||
|
|
||||||
|
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||||
|
const uint head_offset =
|
||||||
|
token_idx * num_heads * head_size + head_idx * head_size;
|
||||||
|
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||||
|
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||||
|
scalar_t* output_head_ptr = output + head_offset;
|
||||||
|
|
||||||
|
float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||||
|
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||||
|
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||||
|
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
|
||||||
|
|
||||||
|
const float max_lse = fmaxf(p_lse, s_lse);
|
||||||
|
p_lse = p_lse - max_lse;
|
||||||
|
s_lse = s_lse - max_lse;
|
||||||
|
const float p_se = expf(p_lse);
|
||||||
|
const float s_se = expf(s_lse);
|
||||||
|
const float out_se = p_se + s_se;
|
||||||
|
const float p_scale = p_se / out_se;
|
||||||
|
const float s_scale = s_se / out_se;
|
||||||
|
|
||||||
|
if (pack_offset < head_size) {
|
||||||
|
// Pack 128b load
|
||||||
|
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||||
|
prefix_head_ptr)[pack_offset / pack_size];
|
||||||
|
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||||
|
suffix_head_ptr)[pack_offset / pack_size];
|
||||||
|
pack_128b_t o_out_pack;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (uint i = 0; i < pack_size; ++i) {
|
||||||
|
// Always use float for FMA to keep high precision.
|
||||||
|
// half(uint16_t), bfloat16, float -> float.
|
||||||
|
const float p_out_f =
|
||||||
|
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||||
|
const float s_out_f =
|
||||||
|
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||||
|
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||||
|
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||||
|
// float -> half(uint16_t), bfloat16, float.
|
||||||
|
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pack 128b storage
|
||||||
|
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||||
|
o_out_pack;
|
||||||
|
}
|
||||||
|
// We only need to write to output_lse once per head.
|
||||||
|
if (output_lse != nullptr && pack_idx == 0) {
|
||||||
|
float out_lse = logf(out_se) + max_lse;
|
||||||
|
output_lse[head_idx * num_tokens + token_idx] = out_lse;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace vllm
|
||||||
|
|
||||||
|
// The following macro is used to dispatch the conversion function based on
|
||||||
|
// the output data type. The FN is a macro that calls a function with
|
||||||
|
// template<typename scalar_t>.
|
||||||
|
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||||
|
{ \
|
||||||
|
if (scalar_dtype == at::ScalarType::Float) { \
|
||||||
|
fn(float); \
|
||||||
|
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||||
|
fn(uint16_t); \
|
||||||
|
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||||
|
fn(__nv_bfloat16); \
|
||||||
|
} else { \
|
||||||
|
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||||
|
{ \
|
||||||
|
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||||
|
<<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||||
|
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||||
|
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||||
|
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||||
|
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||||
|
num_heads, head_size); \
|
||||||
|
}
|
||||||
|
|
||||||
|
/*@brief Merges the attention states from prefix and suffix
|
||||||
|
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||||
|
*
|
||||||
|
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||||
|
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||||
|
* @param prefix_output [n,h,d] The prefix attention states.
|
||||||
|
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||||
|
* states.
|
||||||
|
* @param suffix_output [n,h,d] The suffix attention states.
|
||||||
|
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
|
||||||
|
* states.
|
||||||
|
*/
|
||||||
|
template <typename scalar_t>
|
||||||
|
void merge_attn_states_launcher(torch::Tensor& output,
|
||||||
|
std::optional<torch::Tensor> output_lse,
|
||||||
|
const torch::Tensor& prefix_output,
|
||||||
|
const torch::Tensor& prefix_lse,
|
||||||
|
const torch::Tensor& suffix_output,
|
||||||
|
const torch::Tensor& suffix_lse) {
|
||||||
|
constexpr uint NUM_THREADS = 128;
|
||||||
|
const uint num_tokens = output.size(0);
|
||||||
|
const uint num_heads = output.size(1);
|
||||||
|
const uint head_size = output.size(2);
|
||||||
|
const uint pack_size = 16 / sizeof(scalar_t);
|
||||||
|
TORCH_CHECK(head_size % pack_size == 0,
|
||||||
|
"headsize must be multiple of pack_size:", pack_size);
|
||||||
|
float* output_lse_ptr = nullptr;
|
||||||
|
if (output_lse.has_value()) {
|
||||||
|
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||||
|
}
|
||||||
|
  // Process one pack of elements per thread. For float, the pack_size is 4;
  // for half/bf16, the pack_size is 8.
|
||||||
|
const uint threads_per_head = head_size / pack_size;
|
||||||
|
const uint total_threads = num_tokens * num_heads * threads_per_head;
|
||||||
|
|
||||||
|
dim3 block(NUM_THREADS);
|
||||||
|
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||||
|
|
||||||
|
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||||
|
auto stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||||
|
{ \
|
||||||
|
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
|
||||||
|
prefix_lse, suffix_output, \
|
||||||
|
suffix_lse); \
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge_attn_states(torch::Tensor& output,
|
||||||
|
std::optional<torch::Tensor> output_lse,
|
||||||
|
const torch::Tensor& prefix_output,
|
||||||
|
const torch::Tensor& prefix_lse,
|
||||||
|
const torch::Tensor& suffix_output,
|
||||||
|
const torch::Tensor& suffix_lse) {
|
||||||
|
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||||
|
}
|
||||||
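Since the CUDA kernel above is dense, here is a scalar Python restatement of the merge rule it implements (section 2.2 of the cited paper): rescale each partial output by softmax weights derived from the two log-sum-exp values, and combine the LSEs with the usual max-shift trick. This is an editorial reference sketch for a single (token, head) pair, not a vLLM API.

import math

def merge_attn_state(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """Merge two partial attention results for one (token, head) pair."""
    max_lse = max(prefix_lse, suffix_lse)
    p_se = math.exp(prefix_lse - max_lse)
    s_se = math.exp(suffix_lse - max_lse)
    out_se = p_se + s_se
    merged = [p * p_se / out_se + s * s_se / out_se
              for p, s in zip(prefix_out, suffix_out)]
    merged_lse = math.log(out_se) + max_lse
    return merged, merged_lse

# merge_attn_state([1.0, 2.0], 0.5, [3.0, 4.0], 1.5)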
@@ -350,8 +350,8 @@ __global__ void concat_and_cache_mla_kernel(

 }  // namespace vllm

-// KV_T is the stored data type of kv-cache.
-// CACHE_T is the data type of key and value tensors.
+// KV_T is the data type of key and value tensors.
+// CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
 #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)   \
   vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
@@ -393,8 +393,8 @@ void reshape_and_cache(
       CALL_RESHAPE_AND_CACHE)
 }

-// KV_T is the stored data type of kv-cache.
-// CACHE_T is the data type of key and value tensors.
+// KV_T is the data type of key and value tensors.
+// CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
 #define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)   \
   vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
@@ -446,8 +446,8 @@ void reshape_and_cache_flash(
       CALL_RESHAPE_AND_CACHE_FLASH);
 }

-// KV_T is the stored data type of kv-cache.
-// CACHE_T is the data type of key and value tensors.
+// KV_T is the data type of key and value tensors.
+// CACHE_T is the stored data type of kv-cache.
 // KV_DTYPE is the real data type of kv-cache.
 #define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)   \
   vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
@ -88,6 +88,48 @@ void reshape_and_cache_cpu_impl(
|
|||||||
}
|
}
|
||||||
}; // namespace
|
}; // namespace
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
void concat_and_cache_mla_cpu_impl(
|
||||||
|
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||||
|
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||||
|
scalar_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||||
|
// + pe_dim)]
|
||||||
|
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||||
|
const int num_tokens, //
|
||||||
|
const int block_stride, //
|
||||||
|
const int entry_stride, //
|
||||||
|
const int kv_c_stride, //
|
||||||
|
const int k_pe_stride, //
|
||||||
|
const int kv_lora_rank, //
|
||||||
|
const int pe_dim, //
|
||||||
|
const int block_size //
|
||||||
|
) {
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
|
||||||
|
const int64_t slot_idx = slot_mapping[token_idx];
|
||||||
|
// NOTE: slot_idx can be -1 if the token is padded
|
||||||
|
if (slot_idx < 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const int64_t block_idx = slot_idx / block_size;
|
||||||
|
const int64_t block_offset = slot_idx % block_size;
|
||||||
|
|
||||||
|
auto copy = [&](const scalar_t* __restrict__ src,
|
||||||
|
scalar_t* __restrict__ dst, int src_stride, int dst_stride,
|
||||||
|
int size, int offset) {
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
const int64_t src_idx = token_idx * src_stride + i;
|
||||||
|
const int64_t dst_idx =
|
||||||
|
block_idx * block_stride + block_offset * entry_stride + i + offset;
|
||||||
|
dst[dst_idx] = src[src_idx];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||||
|
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Note: the key_caches and value_caches vectors are constant but
|
// Note: the key_caches and value_caches vectors are constant but
|
||||||
// not the Tensors they contain. The vectors need to be const refs
|
// not the Tensors they contain. The vectors need to be const refs
|
||||||
// in order to satisfy pytorch's C++ operator registration code.
|
// in order to satisfy pytorch's C++ operator registration code.
|
||||||
@ -134,6 +176,38 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void concat_and_cache_mla(
|
||||||
|
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||||
|
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||||
|
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||||
|
// pe_dim)]
|
||||||
|
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||||
|
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||||
|
int num_tokens = slot_mapping.size(0);
|
||||||
|
int kv_lora_rank = kv_c.size(1);
|
||||||
|
int pe_dim = k_pe.size(1);
|
||||||
|
int block_size = kv_cache.size(1);
|
||||||
|
|
||||||
|
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||||
|
TORCH_CHECK(kv_cache_dtype != "fp8");
|
||||||
|
|
||||||
|
int kv_c_stride = kv_c.stride(0);
|
||||||
|
int k_pe_stride = k_pe.stride(0);
|
||||||
|
int block_stride = kv_cache.stride(0);
|
||||||
|
int entry_stride = kv_cache.stride(1);
|
||||||
|
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
kv_c.scalar_type(), "concat_and_cache_mla_cpu_impl", [&] {
|
||||||
|
CPU_KERNEL_GUARD_IN(concat_and_cache_mla_cpu_impl)
|
||||||
|
concat_and_cache_mla_cpu_impl<scalar_t>(
|
||||||
|
kv_c.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(),
|
||||||
|
kv_cache.data_ptr<scalar_t>(), slot_mapping.data_ptr<int64_t>(),
|
||||||
|
num_tokens, block_stride, entry_stride, kv_c_stride, k_pe_stride,
|
||||||
|
kv_lora_rank, pe_dim, block_size);
|
||||||
|
CPU_KERNEL_GUARD_OUT(concat_and_cache_mla_cpu_impl)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
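To make the indexing in concat_and_cache_mla_cpu_impl above easier to follow, a small Python restatement of where one token's entries land in the MLA cache: kv_c occupies the first kv_lora_rank positions of the cache row and k_pe follows it. This mirrors the copy lambda above and is illustrative only.

def mla_cache_dst_indices(slot_idx, block_size, block_stride, entry_stride,
                          kv_lora_rank, pe_dim):
    """Flat kv_cache offsets for one token: kv_c first, then k_pe."""
    block_idx, block_offset = divmod(slot_idx, block_size)
    base = block_idx * block_stride + block_offset * entry_stride
    kv_c_dst = [base + i for i in range(kv_lora_rank)]
    k_pe_dst = [base + kv_lora_rank + i for i in range(pe_dim)]
    return kv_c_dst, k_pe_dst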
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||||
const torch::Tensor& block_mapping) {
|
const torch::Tensor& block_mapping) {
|
||||||
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
|
||||||
|
|||||||
@ -78,9 +78,14 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
|||||||
|
|
||||||
__m256i reg;
|
__m256i reg;
|
||||||
|
|
||||||
|
// normal load
|
||||||
explicit FP16Vec16(const void* ptr)
|
explicit FP16Vec16(const void* ptr)
|
||||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||||
|
|
||||||
|
// non-temproal load
|
||||||
|
explicit FP16Vec16(bool, void* ptr)
|
||||||
|
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||||
|
|
||||||
explicit FP16Vec16(const FP32Vec16&);
|
explicit FP16Vec16(const FP32Vec16&);
|
||||||
|
|
||||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||||
@ -110,9 +115,14 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
|||||||
|
|
||||||
__m256i reg;
|
__m256i reg;
|
||||||
|
|
||||||
|
// normal load
|
||||||
explicit BF16Vec16(const void* ptr)
|
explicit BF16Vec16(const void* ptr)
|
||||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||||
|
|
||||||
|
// non-temporal load
|
||||||
|
explicit BF16Vec16(bool, void* ptr)
|
||||||
|
: reg(_mm256_stream_load_si256((__m256i*)ptr)) {}
|
||||||
|
|
||||||
explicit BF16Vec16(const FP32Vec16&);
|
explicit BF16Vec16(const FP32Vec16&);
|
||||||
|
|
||||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||||
@ -130,6 +140,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
|||||||
|
|
||||||
__m512i reg;
|
__m512i reg;
|
||||||
|
|
||||||
|
explicit BF16Vec32() : reg(_mm512_setzero_si512()) {}
|
||||||
|
|
||||||
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||||
|
|
||||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||||
@ -311,8 +323,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
|||||||
|
|
||||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||||
|
|
||||||
|
// normal load
|
||||||
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||||
|
|
||||||
|
// non-temporal load
|
||||||
|
explicit FP32Vec16(bool, void* ptr)
|
||||||
|
: reg((__m512)_mm512_stream_load_si512(ptr)) {}
|
||||||
|
|
||||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||||
|
|
||||||
explicit FP32Vec16(const FP32Vec4& data)
|
explicit FP32Vec16(const FP32Vec4& data)
|
||||||
@ -545,6 +562,33 @@ struct INT8Vec16 : public Vec<INT8Vec16> {
|
|||||||
_mm_mask_storeu_epi8(ptr, mask, reg);
|
_mm_mask_storeu_epi8(ptr, mask, reg);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct INT8Vec64 : public Vec<INT8Vec64> {
|
||||||
|
constexpr static int VEC_ELEM_NUM = 64;
|
||||||
|
union AliasReg {
|
||||||
|
__m512i reg;
|
||||||
|
int8_t values[VEC_ELEM_NUM];
|
||||||
|
};
|
||||||
|
|
||||||
|
__m512i reg;
|
||||||
|
|
||||||
|
// normal load
|
||||||
|
explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {}
|
||||||
|
|
||||||
|
// non-temporal load
|
||||||
|
explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {}
|
||||||
|
|
||||||
|
void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); }
|
||||||
|
|
||||||
|
void save(int8_t* ptr, const int elem_num) const {
|
||||||
|
constexpr uint64_t M = 0xFFFFFFFFFFFFFFFF;
|
||||||
|
__mmask64 mask = _cvtu64_mask64(M >> (64 - elem_num));
|
||||||
|
_mm512_mask_storeu_epi8(ptr, mask, reg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// non-temporal save
|
||||||
|
void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); }
|
||||||
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -655,6 +699,22 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
|||||||
|
|
||||||
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||||
|
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
inline void non_temporal_save(FP16Vec16& vec, void* ptr) {
|
||||||
|
_mm256_stream_si256((__m256i*)ptr, vec.reg);
|
||||||
|
}
|
||||||
|
inline void non_temporal_save(BF16Vec32& vec, void* ptr) {
|
||||||
|
_mm512_stream_si512((__m512i*)ptr, vec.reg);
|
||||||
|
}
|
||||||
|
inline void non_temporal_save(BF16Vec16& vec, void* ptr) {
|
||||||
|
_mm256_stream_si256((__m256i*)ptr, vec.reg);
|
||||||
|
}
|
||||||
|
inline void non_temporal_save(FP32Vec16& vec, void* ptr) {
|
||||||
|
_mm512_stream_ps((float*)ptr, vec.reg);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
inline void mem_barrier() { _mm_mfence(); }
|
||||||
}; // namespace vec_op
|
}; // namespace vec_op
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
csrc/cpu/mla_decode.cpp (new file, 393 lines)
@@ -0,0 +1,393 @@
|
|||||||
|
#include "cpu_types.hpp"
|
||||||
|
#include <float.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename scalar_t>
|
||||||
|
struct KernelVecType {
|
||||||
|
using qk_load_vec_type = void;
|
||||||
|
using qk_vec_type = void;
|
||||||
|
using v_load_vec_type = void;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<float> {
|
||||||
|
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||||
|
using qk_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::FP32Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::Half> {
|
||||||
|
#if defined(__powerpc64__) || defined(__s390x__)
|
||||||
|
// Power and s390x architecture-specific vector types
|
||||||
|
using qk_load_vec_type = vec_op::FP32Vec16;
|
||||||
|
using qk_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::FP32Vec16;
|
||||||
|
#else
|
||||||
|
// Fallback for other architectures, including x86
|
||||||
|
using qk_load_vec_type = vec_op::FP16Vec16;
|
||||||
|
using qk_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::FP16Vec16;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __AVX512BF16__
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::BFloat16> {
|
||||||
|
using qk_load_vec_type = vec_op::BF16Vec32;
|
||||||
|
using qk_vec_type = vec_op::BF16Vec32;
|
||||||
|
using v_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
};
|
||||||
|
#elif defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
|
||||||
|
// pass
|
||||||
|
#else
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::BFloat16> {
|
||||||
|
using qk_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
using qk_vec_type = vec_op::FP32Vec16;
|
||||||
|
using v_load_vec_type = vec_op::BF16Vec16;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE, int HEAD_UNROLL,
|
||||||
|
typename qk_vec_type>
|
||||||
|
void mla_decode_block_head(
|
||||||
|
const qk_vec_type* __restrict__ q_vecs, // [HEAD_UNROLL, head_dim]
|
||||||
|
const qk_vec_type* __restrict__ k_vecs, // [block_size, head_dim]
|
||||||
|
const vec_op::FP32Vec16* __restrict v_vecs_f32, // [block_size, v_head_dim]
|
||||||
|
float* __restrict__ acc_out, // [HEAD_UNROLL, v_head_dim]
|
||||||
|
float* __restrict__ acc_lse, // [HEAD_UNROLL]
|
||||||
|
const float scale, const int num_tokens) {
|
||||||
|
using f32_vec_type = vec_op::FP32Vec16;
|
||||||
|
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||||
|
constexpr int V_NUM_ELEM = f32_vec_type::VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
float logits[BLOCK_SIZE][HEAD_UNROLL] = {}; // initialize to zeros
|
||||||
|
float max_val[HEAD_UNROLL];
|
||||||
|
std::fill(max_val, max_val + HEAD_UNROLL, -FLT_MAX);
|
||||||
|
|
||||||
|
f32_vec_type acc_vec[BLOCK_SIZE][HEAD_UNROLL];
|
||||||
|
for (int i = 0; i < HEAD_DIM; i += QK_NUM_ELEM) {
|
||||||
|
// load to registers
|
||||||
|
qk_vec_type q_vec[HEAD_UNROLL];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||||
|
q_vec[unroll] =
|
||||||
|
qk_vec_type{q_vecs[(i + unroll * HEAD_DIM) / QK_NUM_ELEM]};
|
||||||
|
|
||||||
|
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||||
|
qk_vec_type k_vec(k_vecs[(block_offset * HEAD_DIM + i) / QK_NUM_ELEM]);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||||
|
vec_op::fma(acc_vec[block_offset][unroll], q_vec[unroll], k_vec);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||||
|
const float acc = acc_vec[block_offset][unroll].reduce_sum() * scale;
|
||||||
|
logits[block_offset][unroll] = acc;
|
||||||
|
max_val[unroll] = std::max(max_val[unroll], acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float sum_exp[HEAD_UNROLL] = {};
|
||||||
|
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||||
|
const float val =
|
||||||
|
std::exp(logits[block_offset][unroll] - max_val[unroll]);
|
||||||
|
logits[block_offset][unroll] = val;
|
||||||
|
sum_exp[unroll] += val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f32_vec_type this_out[V_HEAD_DIM / V_NUM_ELEM][HEAD_UNROLL];
|
||||||
|
|
||||||
|
for (int block_offset = 0; block_offset < num_tokens; ++block_offset) {
|
||||||
|
// load to registers
|
||||||
|
f32_vec_type scale_[HEAD_UNROLL];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||||
|
scale_[unroll] =
|
||||||
|
f32_vec_type{logits[block_offset][unroll] / sum_exp[unroll]};
|
||||||
|
|
||||||
|
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
|
||||||
|
f32_vec_type v_vec(
|
||||||
|
v_vecs_f32[(block_offset * HEAD_DIM + i) / V_NUM_ELEM]);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll)
|
||||||
|
vec_op::fma(this_out[i / V_NUM_ELEM][unroll], v_vec, scale_[unroll]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge attention state
|
||||||
|
// section 2.2 in https://arxiv.org/pdf/2501.01005
|
||||||
|
f32_vec_type prev_scale[HEAD_UNROLL];
|
||||||
|
f32_vec_type curr_scale[HEAD_UNROLL];
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||||
|
const float prev_lse = acc_lse[unroll];
|
||||||
|
const float curr_lse = std::log(sum_exp[unroll]) +
|
||||||
|
max_val[unroll]; // add back max_val to get true lse
|
||||||
|
// softmax trick
|
||||||
|
const float max_lse = std::max(prev_lse, curr_lse);
|
||||||
|
const float prev_sum_exp = std::exp(prev_lse - max_lse);
|
||||||
|
const float curr_sum_exp = std::exp(curr_lse - max_lse);
|
||||||
|
|
||||||
|
const float new_sum_exp = prev_sum_exp + curr_sum_exp;
|
||||||
|
acc_lse[unroll] = std::log(new_sum_exp) + max_lse;
|
||||||
|
|
||||||
|
prev_scale[unroll] = f32_vec_type{prev_sum_exp / new_sum_exp};
|
||||||
|
curr_scale[unroll] = f32_vec_type{curr_sum_exp / new_sum_exp};
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < V_HEAD_DIM; i += V_NUM_ELEM) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int unroll = 0; unroll < HEAD_UNROLL; ++unroll) {
|
||||||
|
f32_vec_type o_vec(acc_out + i + V_HEAD_DIM * unroll);
|
||||||
|
o_vec = o_vec * prev_scale[unroll] +
|
||||||
|
this_out[i / V_NUM_ELEM][unroll] * curr_scale[unroll];
|
||||||
|
o_vec.save(acc_out + i + V_HEAD_DIM * unroll);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
q_vecs += HEAD_DIM / QK_NUM_ELEM * HEAD_UNROLL;
|
||||||
|
acc_out += V_HEAD_DIM * HEAD_UNROLL;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE,
|
||||||
|
typename qk_vec_type>
|
||||||
|
void mla_decode_block(
|
||||||
|
const qk_vec_type* __restrict__ q_vecs, // [num_heads, head_dim]
|
||||||
|
const scalar_t* __restrict__ kv_cache, // [block_size, head_dim]
|
||||||
|
float* __restrict__ acc_out, // [num_heads, v_head_dim]
|
||||||
|
float* __restrict__ acc_lse, // [num_heads]
|
||||||
|
const int num_heads, const float scale, const int num_tokens) {
|
||||||
|
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
|
||||||
|
static_assert(
|
||||||
|
std::is_same<qk_vec_type,
|
||||||
|
typename KernelVecType<scalar_t>::qk_vec_type>::value);
|
||||||
|
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
|
||||||
|
using f32_vec_type = vec_op::FP32Vec16;
|
||||||
|
static_assert(qk_load_vec_type::VEC_ELEM_NUM == qk_vec_type::VEC_ELEM_NUM);
|
||||||
|
static_assert(v_load_vec_type::VEC_ELEM_NUM == f32_vec_type::VEC_ELEM_NUM);
|
||||||
|
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||||
|
constexpr int V_NUM_ELEM = v_load_vec_type::VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
const qk_vec_type* k_vecs;
|
||||||
|
const f32_vec_type* v_vecs_f32;
|
||||||
|
float* kv_cache_f32 = nullptr;
|
||||||
|
|
||||||
|
if constexpr (!std::is_same<scalar_t, float>::value) {
|
||||||
|
// convert KV cache block to FP32 to reuse it across query heads and
|
||||||
|
// attn @ V computation, since FP16/BF16->FP32 is expensive.
|
||||||
|
// TODO: move malloc outside of this fn to reuse across iterations.
|
||||||
|
const int nbytes = BLOCK_SIZE * HEAD_DIM * sizeof(float);
|
||||||
|
kv_cache_f32 = static_cast<float*>(std::aligned_alloc(64, nbytes));
|
||||||
|
|
||||||
|
for (int block_offset = 0; block_offset < num_tokens; ++block_offset)
|
||||||
|
for (int i = 0; i < HEAD_DIM; i += V_NUM_ELEM) {
|
||||||
|
v_load_vec_type kv_load_vec(kv_cache + block_offset * HEAD_DIM + i);
|
||||||
|
f32_vec_type kv_vec_f32(kv_load_vec);
|
||||||
|
kv_vec_f32.save(kv_cache_f32 + block_offset * HEAD_DIM + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (std::is_same<qk_load_vec_type, qk_vec_type>::value) {
|
||||||
|
// for AVX512_BF16, Q @ K.T uses BF16 for K (no conversion)
|
||||||
|
// NOTE: in this case, we only need to convert the V section to FP32.
|
||||||
|
// But for simplicity, we will convert the whole KV block to FP32.
|
||||||
|
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
|
||||||
|
} else {
|
||||||
|
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache_f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
// attn @ V always use FP32 for V, since attn is FP32.
|
||||||
|
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache_f32);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// KV cache is FP32. don't need to do anything.
|
||||||
|
k_vecs = reinterpret_cast<const qk_vec_type*>(kv_cache);
|
||||||
|
v_vecs_f32 = reinterpret_cast<const f32_vec_type*>(kv_cache);
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute 2 heads at the same time to improve ILP and
|
||||||
|
// take advantage of register cache for K and V.
|
||||||
|
constexpr int HEAD_UNROLL = 2;
|
||||||
|
for (int iter = 0; iter < num_heads / HEAD_UNROLL; ++iter) {
|
||||||
|
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, HEAD_UNROLL>(
|
||||||
|
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
|
||||||
|
|
||||||
|
q_vecs += HEAD_UNROLL * HEAD_DIM / QK_NUM_ELEM;
|
||||||
|
acc_out += HEAD_UNROLL * V_HEAD_DIM;
|
||||||
|
acc_lse += HEAD_UNROLL;
|
||||||
|
}
|
||||||
|
|
||||||
|
// take care of the remaining heads
|
||||||
|
for (int iter = 0; iter < num_heads % HEAD_UNROLL; ++iter) {
|
||||||
|
mla_decode_block_head<HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE, 1>(
|
||||||
|
q_vecs, k_vecs, v_vecs_f32, acc_out, acc_lse, scale, num_tokens);
|
||||||
|
|
||||||
|
q_vecs += HEAD_DIM / QK_NUM_ELEM;
|
||||||
|
acc_out += V_HEAD_DIM;
|
||||||
|
acc_lse += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kv_cache_f32 != nullptr) {
|
||||||
|
std::free(kv_cache_f32);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
template <typename scalar_t, int HEAD_DIM, int V_HEAD_DIM, int BLOCK_SIZE>
|
||||||
|
void mla_decode_kvcache_cpu_impl(
|
||||||
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, v_head_dim]
|
||||||
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_dim]
|
||||||
|
const scalar_t* __restrict__ kv_cache, // [num_blocks, block_size,
|
||||||
|
// head_dim]
|
||||||
|
const int num_heads, const float scale,
|
||||||
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||||
|
const int* __restrict__ seq_lens, // [num_seqs]
|
||||||
|
const int max_num_blocks_per_seq, const int o_stride, const int q_stride,
|
||||||
|
const int kv_stride, const int num_seqs) {
|
||||||
|
using qk_load_vec_type = typename KernelVecType<scalar_t>::qk_load_vec_type;
|
||||||
|
using qk_vec_type = typename KernelVecType<scalar_t>::qk_vec_type;
|
||||||
|
constexpr int QK_NUM_ELEM = qk_vec_type::VEC_ELEM_NUM;
|
||||||
|
|
||||||
|
// shared across threads
|
||||||
|
const int max_threads = omp_get_max_threads();
|
||||||
|
const int acc_out_nbytes =
|
||||||
|
max_threads * num_heads * V_HEAD_DIM * sizeof(float);
|
||||||
|
float* acc_out = static_cast<float*>(std::aligned_alloc(64, acc_out_nbytes));
|
||||||
|
std::vector<float> acc_lse(max_threads * num_heads);
|
||||||
|
|
||||||
|
// allocate memory to pre-convert query to FP32 later
|
||||||
|
float* q_f32;
|
||||||
|
constexpr bool PRE_CONVERT_QUERY =
|
||||||
|
!std::is_same<scalar_t, float>::value &&
|
||||||
|
std::is_same<qk_vec_type, vec_op::FP32Vec16>::value;
|
||||||
|
if constexpr (PRE_CONVERT_QUERY) {
|
||||||
|
const int q_f32_nbytes = num_heads * HEAD_DIM * sizeof(float);
|
||||||
|
q_f32 = static_cast<float*>(std::aligned_alloc(64, q_f32_nbytes));
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma omp parallel
|
||||||
|
{
|
||||||
|
const int num_threads = omp_get_num_threads();
|
||||||
|
const int thread_id = omp_get_thread_num();
|
||||||
|
float* __restrict__ acc_out_thread =
|
||||||
|
acc_out + thread_id * num_heads * V_HEAD_DIM;
|
||||||
|
float* __restrict__ acc_lse_thread = acc_lse.data() + thread_id * num_heads;
|
||||||
|
|
||||||
|
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
|
||||||
|
// reset accumulator
|
||||||
|
std::fill(acc_out_thread, acc_out_thread + num_heads * V_HEAD_DIM, 0.0f);
|
||||||
|
std::fill(acc_lse_thread, acc_lse_thread + num_heads, -FLT_MAX);
|
||||||
|
|
||||||
|
const int seq_len = seq_lens[seq_idx];
|
||||||
|
const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
const int last_block_size = seq_len - (block_num - 1) * BLOCK_SIZE;
|
||||||
|
|
||||||
|
const qk_vec_type* q_vecs;
|
||||||
|
if constexpr (PRE_CONVERT_QUERY) {
|
||||||
|
// pre-convert query to FP32 since FP16/BF16->FP32 is slow.
|
||||||
|
#pragma omp for
|
||||||
|
for (int i = 0; i < num_heads * HEAD_DIM; i += QK_NUM_ELEM) {
|
||||||
|
qk_load_vec_type q_load_vec(q + seq_idx * q_stride + i);
|
||||||
|
qk_vec_type q_vec(q_load_vec);
|
||||||
|
q_vec.save(q_f32 + i);
|
||||||
|
}
|
||||||
|
q_vecs = reinterpret_cast<const qk_vec_type*>(q_f32);
|
||||||
|
} else {
|
||||||
|
q_vecs = reinterpret_cast<const qk_vec_type*>(q + seq_idx * q_stride);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma omp for
|
||||||
|
for (int block_idx = 0; block_idx < block_num; ++block_idx) {
|
||||||
|
const int physical_block_idx =
|
||||||
|
block_tables[seq_idx * max_num_blocks_per_seq + block_idx];
|
||||||
|
const int num_tokens =
|
||||||
|
block_idx < block_num - 1 ? BLOCK_SIZE : last_block_size;
|
||||||
|
|
||||||
|
mla_decode_block<scalar_t, HEAD_DIM, V_HEAD_DIM, BLOCK_SIZE>(
|
||||||
|
q_vecs, kv_cache + physical_block_idx * kv_stride, acc_out_thread,
|
||||||
|
acc_lse_thread, num_heads, scale, num_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge attention states across threads
|
||||||
|
// section 2.2 in https://arxiv.org/pdf/2501.01005
|
||||||
|
// each thread is responsible for 1 head
|
||||||
|
#pragma omp for
|
||||||
|
for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
|
||||||
|
float* acc_lse_head = acc_lse.data() + head_idx;
|
||||||
|
float* acc_out_head = acc_out + head_idx * V_HEAD_DIM;
|
||||||
|
|
||||||
|
float max_val = -FLT_MAX;
|
||||||
|
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||||
|
max_val = std::max(max_val, acc_lse_head[thread_id_ * num_heads]);
|
||||||
|
}
|
||||||
|
|
||||||
|
float sum_exp = 0.0f;
|
||||||
|
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||||
|
float val = std::exp(acc_lse_head[thread_id_ * num_heads] - max_val);
|
||||||
|
acc_lse_head[thread_id_ * num_heads] = val;
|
||||||
|
sum_exp += val;
|
||||||
|
}
|
||||||
|
|
||||||
|
float inv_sum = 1.0f / sum_exp;
|
||||||
|
float out_head[V_HEAD_DIM] = {};
|
||||||
|
for (int thread_id_ = 0; thread_id_ < num_threads; ++thread_id_) {
|
||||||
|
float scale_ = acc_lse_head[thread_id_ * num_heads] * inv_sum;
|
||||||
|
for (int i = 0; i < V_HEAD_DIM; ++i) {
|
||||||
|
out_head[i] +=
|
||||||
|
acc_out_head[thread_id_ * num_heads * V_HEAD_DIM + i] * scale_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < V_HEAD_DIM; ++i) {
|
||||||
|
vec_op::storeFP32(out_head[i], out + seq_idx * o_stride +
|
||||||
|
head_idx * V_HEAD_DIM + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (PRE_CONVERT_QUERY) {
|
||||||
|
std::free(q_f32);
|
||||||
|
}
|
||||||
|
std::free(acc_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
|
||||||
|
torch::Tensor& kv_cache, double scale,
|
||||||
|
torch::Tensor& block_tables, torch::Tensor& seq_lens) {
|
||||||
|
const int num_seqs = query.size(0);
|
||||||
|
const int num_heads = query.size(1);
|
||||||
|
const int head_dim = query.size(2);
|
||||||
|
const int block_size = kv_cache.size(1);
|
||||||
|
const int v_head_dim = out.size(2);
|
||||||
|
|
||||||
|
const int max_num_blocks_per_seq = block_tables.size(1);
|
||||||
|
const int o_stride = out.stride(0);
|
||||||
|
const int q_stride = query.stride(0);
|
||||||
|
const int kv_stride = kv_cache.stride(0);
|
||||||
|
|
||||||
|
VLLM_DISPATCH_FLOATING_TYPES(
|
||||||
|
query.scalar_type(), "mla_decode_kvcache_cpu_impl", [&] {
|
||||||
|
CPU_KERNEL_GUARD_IN(mla_decode_kvcache_cpu_impl)
|
||||||
|
if (head_dim == 576 && v_head_dim == 512 && block_size == 16)
|
||||||
|
mla_decode_kvcache_cpu_impl<scalar_t, 576, 512, 16>(
|
||||||
|
out.data_ptr<scalar_t>(), query.data_ptr<scalar_t>(),
|
||||||
|
kv_cache.data_ptr<scalar_t>(), num_heads, scale,
|
||||||
|
block_tables.data_ptr<int>(), seq_lens.data_ptr<int>(),
|
||||||
|
max_num_blocks_per_seq, o_stride, q_stride, kv_stride, num_seqs);
|
||||||
|
else
|
||||||
|
TORCH_CHECK(false, "Unsupported block size: ", block_size);
|
||||||
|
CPU_KERNEL_GUARD_OUT(mla_decode_kvcache_cpu_impl)
|
||||||
|
});
|
||||||
|
}
|
||||||
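Before the next file, a compact Python restatement of what mla_decode_block_head above computes for one query head and one KV block: scaled q.k logits, a local softmax, a weighted sum over the value part of each cache row, and a running (output, lse) merge using the same log-sum-exp trick as merge_attn_states. It is an editorial readability aid, not the CPU kernel's API.

import math

def mla_decode_one_block(q, kv_block, v_head_dim, scale, acc_out, acc_lse):
    """q: [head_dim]; kv_block: [num_tokens][head_dim]; the first v_head_dim
    entries of each cache row double as the value vector (MLA layout)."""
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in kv_block]
    max_val = max(logits)
    weights = [math.exp(x - max_val) for x in logits]
    sum_exp = sum(weights)
    block_out = [sum(w * k[i] for w, k in zip(weights, kv_block)) / sum_exp
                 for i in range(v_head_dim)]
    block_lse = math.log(sum_exp) + max_val
    # Merge with the running accumulator (same rule as merge_attn_states).
    max_lse = max(acc_lse, block_lse)
    p, c = math.exp(acc_lse - max_lse), math.exp(block_lse - max_lse)
    new = p + c
    acc_out = [o * p / new + b * c / new for o, b in zip(acc_out, block_out)]
    return acc_out, math.log(new) + max_lse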
csrc/cpu/shm.cpp (new file, 781 lines)
@@ -0,0 +1,781 @@
|
|||||||
|
#include "cpu/cpu_types.hpp"
|
||||||
|
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#include <sys/stat.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#define MAX_SHM_RANK_NUM 8
|
||||||
|
#define MAX_THREAD_NUM 12
|
||||||
|
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||||
|
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||||
|
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
struct KernelVecType {
|
||||||
|
using scalar_vec_t = void;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<float> {
|
||||||
|
using scalar_vec_t = vec_op::FP32Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::BFloat16> {
|
||||||
|
using scalar_vec_t = vec_op::BF16Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct KernelVecType<c10::Half> {
|
||||||
|
using scalar_vec_t = vec_op::FP16Vec16;
|
||||||
|
};
|
||||||
|
|
||||||
|
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||||
|
|
||||||
|
struct ThreadSHMContext {
|
||||||
|
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||||
|
int thread_id;
|
||||||
|
int thread_num;
|
||||||
|
int rank;
|
||||||
|
int group_size;
|
||||||
|
size_t _spinning_count;
|
||||||
|
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||||
|
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||||
|
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||||
|
|
||||||
|
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||||
|
const int group_size, void* thread_shm_ptr)
|
||||||
|
: thread_id(thread_id),
|
||||||
|
thread_num(thread_num),
|
||||||
|
rank(rank),
|
||||||
|
group_size(group_size),
|
||||||
|
_spinning_count(0) {
|
||||||
|
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||||
|
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||||
|
TORCH_CHECK((size_t)this % 64 == 0);
|
||||||
|
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||||
|
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||||
|
shm_contexts[i] = nullptr;
|
||||||
|
thread_shm_ptrs[i] = nullptr;
|
||||||
|
swizzled_ranks[i] = (i + rank) % group_size;
|
||||||
|
thread_stats[i] = ThreadSHMStat::DONE;
|
||||||
|
}
|
||||||
|
set_context(rank, this, thread_shm_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
|
||||||
|
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
|
||||||
|
TORCH_CHECK(ptr);
|
||||||
|
TORCH_CHECK(thread_shm_ptr);
|
||||||
|
TORCH_CHECK_EQ(ptr->thread_num, thread_num);
|
||||||
|
TORCH_CHECK_EQ(ptr->thread_id, thread_id);
|
||||||
|
shm_contexts[rank] = ptr;
|
||||||
|
thread_shm_ptrs[rank] = thread_shm_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
T* get_thread_shm_ptr(int rank) {
|
||||||
|
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||||
|
|
||||||
|
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||||
|
for (int idx = 0; idx < group_size; ++idx) {
|
||||||
|
int rank = get_swizzled_rank(idx);
|
||||||
|
while (thread_stats[rank] == prev_stat) {
|
||||||
|
++_spinning_count;
|
||||||
|
_mm_pause();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vec_op::mem_barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
|
||||||
|
while (thread_stats[rank] == prev_stat) {
|
||||||
|
++_spinning_count;
|
||||||
|
_mm_pause();
|
||||||
|
}
|
||||||
|
vec_op::mem_barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_thread_stat(ThreadSHMStat stat) {
|
||||||
|
for (int idx = 0; idx < group_size; ++idx) {
|
||||||
|
int rank = get_swizzled_rank(idx);
|
||||||
|
shm_contexts[rank]->thread_stats[this->rank] = stat;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
|
||||||
|
for (int idx = 0; idx < group_size; ++idx) {
|
||||||
|
int rank = get_swizzled_rank(idx);
|
||||||
|
shm_contexts[rank]->thread_stats[target_rank] = stat;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// barrier for all ranks in the group, used for all2all ops
|
||||||
|
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
|
||||||
|
void barrier(ThreadSHMStat next_stat) {
|
||||||
|
if (next_stat == ThreadSHMStat::THREAD_READY) {
|
||||||
|
set_thread_stat(ThreadSHMStat::THREAD_READY);
|
||||||
|
wait_for_all(ThreadSHMStat::DONE);
|
||||||
|
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
|
||||||
|
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
|
||||||
|
wait_for_all(ThreadSHMStat::THREAD_READY);
|
||||||
|
} else if (next_stat == ThreadSHMStat::DONE) {
|
||||||
|
set_thread_stat(ThreadSHMStat::DONE);
|
||||||
|
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(false, "Invalid next_stat to barrier.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string to_string() const {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "SHMContext:";
|
||||||
|
ss << "\nrank: " << rank;
|
||||||
|
ss << "\ngroup_size: " << group_size;
|
||||||
|
ss << "\nthread_num: " << thread_num;
|
||||||
|
ss << "\nthread_id: " << thread_id;
|
||||||
|
|
||||||
|
ss << "\nshm_ctx_stat_loop_seq: [";
|
||||||
|
for (int i = 0; i < group_size; ++i) {
|
||||||
|
ss << swizzled_ranks[i] << ", ";
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
|
||||||
|
ss << "\nshm_contexts: [";
|
||||||
|
for (int i = 0; i < group_size; ++i) {
|
||||||
|
if (shm_contexts[i]) {
|
||||||
|
ss << shm_contexts[i]->rank << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ss << "]";
|
||||||
|
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
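
The DONE -> THREAD_READY -> SHM_DATA_READY -> DONE cycle documented above is easiest to see from the caller's side. A minimal sketch of the pattern the collectives later in this file follow (it mirrors `all_reduce_sum_impl`; `local_buf` and `payload_bytes` are hypothetical names, and the `memcpy_to_shm` helper referenced here is defined further down in the same file):

// Sketch of one round of the per-thread protocol driven through
// ThreadSHMContext::barrier(). Illustrative only.
void barrier_cycle_sketch(ThreadSHMContext* ctx, void* local_buf,
                          int64_t payload_bytes) {
  // 1. Everyone announces it is ready to reuse its shared-memory slot.
  ctx->barrier(ThreadSHMStat::THREAD_READY);
  // 2. Publish this rank's data into its own slot.
  shm_cc_ops::memcpy_to_shm(ctx->get_thread_shm_ptr<int8_t>(ctx->rank),
                            local_buf, payload_bytes);
  // 3. Everyone signals the data is visible; peers may now stream-load it
  //    via get_thread_shm_ptr(get_swizzled_rank(i)).
  ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
  // 4. Everyone marks the round finished so the slots can be reused.
  ctx->barrier(ThreadSHMStat::DONE);
}
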

class SHMManager {
 public:
  explicit SHMManager(const std::string& name, const int rank,
                      const int group_size)
      : _rank(rank),
        _group_size(group_size),
        _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
        _shm_names({""}),
        _shared_mem_ptrs({nullptr}),
        _shm_ctx(nullptr) {
    _shm_names[rank] = get_shm_name(name, rank);
    _shared_mem_ptrs[rank] = init_shm(rank);
    _shm_ctx = reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank]);

    for (int i = 0; i < _thread_num; ++i) {
      ThreadSHMContext* ctx = new (_shm_ctx + i)
          ThreadSHMContext(i, _thread_num, _rank, _group_size,
                           compute_thread_shm_ptr(_shm_ctx, i));
    }
  }

  void join(const std::string& name) {
    for (int rank_idx = 0; rank_idx < _group_size; ++rank_idx) {
      if (rank_idx != _rank) {
        TORCH_CHECK(_shm_names[rank_idx].empty());
        TORCH_CHECK(_shared_mem_ptrs[rank_idx] == nullptr);
        _shm_names[rank_idx] = get_shm_name(name, rank_idx);
        _shared_mem_ptrs[rank_idx] = init_shm(rank_idx);
        ThreadSHMContext* target_ctx =
            reinterpret_cast<ThreadSHMContext*>(_shared_mem_ptrs[rank_idx]);
        for (int thread_idx = 0; thread_idx < _thread_num; ++thread_idx) {
          _shm_ctx[thread_idx].set_context(
              rank_idx, target_ctx + thread_idx,
              compute_thread_shm_ptr(target_ctx, thread_idx));
        }
      }
    }
  }

  ~SHMManager() { destroy_shm(); }

  ThreadSHMContext* get_shm_ctx() const { return _shm_ctx; }

  static std::string get_shm_name(const std::string& name, int rank) {
    return name + "_" + std::to_string(rank);
  }

  static int64_t create_singleton_instance(const std::string& name,
                                           const int group_size,
                                           const int rank) {
    std::lock_guard<std::mutex> guard(SingletonInstancesLock);
    SingletonInstances.emplace_back(
        std::make_unique<SHMManager>(name, rank, group_size));
    return static_cast<int64_t>(SingletonInstances.size() - 1);
  }

  static SHMManager* get_singleton_instance(int64_t handle) {
    return SingletonInstances[handle].get();
  }

 protected:
  static std::vector<std::unique_ptr<SHMManager>> SingletonInstances;
  static std::mutex SingletonInstancesLock;

 private:
  static size_t round_to_alignment(size_t num) {
    return ((num + 63) / 64) * 64;
  }

  int8_t* compute_thread_shm_ptr(ThreadSHMContext* ctx, int thread_id) {
    int8_t* thread_shm_ptr =
        reinterpret_cast<int8_t*>(ctx) +
        round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
    return thread_shm_ptr +
           thread_id * round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES);
  }

  size_t compute_shm_size() {
    const size_t rounded_rank_buffer_size =
        round_to_alignment(PER_THREAD_SHM_BUFFER_BYTES) * _thread_num;
    const size_t rounded_thread_shm_ctx_size =
        round_to_alignment(_thread_num * sizeof(ThreadSHMContext));
    const size_t shm_size =
        rounded_thread_shm_ctx_size + rounded_rank_buffer_size;
    return shm_size;
  }

  void* init_shm(int target_rank) {
    const std::string& shm_name = _shm_names[target_rank];
    const int local_rank = _rank;
    const size_t shm_size = compute_shm_size();

    int fd = -1;
    if (local_rank == target_rank) {
      fd = shm_open(shm_name.c_str(), O_CREAT | O_EXCL | O_RDWR,
                    S_IRUSR | S_IWUSR);

      if (fd == -1)
        TORCH_CHECK(false, "create shm in SHMManager failed. errno: " +
                               std::to_string(errno));

      if (ftruncate(fd, shm_size) == -1)
        TORCH_CHECK(false, "ftruncate in SHMManager failed. errno: " +
                               std::to_string(errno));
    } else {
      fd = shm_open(shm_name.c_str(), O_RDWR, S_IRUSR | S_IWUSR);

      if (fd == -1)
        TORCH_CHECK(false, "open shm in SHMManager failed. errno: " +
                               std::to_string(errno));
    }

    void* shm_ptr = mmap(nullptr, shm_size, PROT_READ | PROT_WRITE,
                         MAP_SHARED | MAP_POPULATE, fd, 0);

    if (shm_ptr == MAP_FAILED) {
      TORCH_CHECK(false,
                  "mmap in SHMManager failed. errno: " + std::to_string(errno));
    }

    if (close(fd) != 0) {
      TORCH_CHECK(
          false, "close in SHMManager failed. errno: " + std::to_string(errno));
    }

    TORCH_CHECK((size_t)shm_ptr % 64 == 0);

    return shm_ptr;
  }

  void destroy_shm() {
    std::stringstream ss;
    ss << "local rank " << _rank << ": [";
    for (int thread_id = 0; thread_id < _thread_num; ++thread_id) {
      ss << _shm_ctx[thread_id]._spinning_count << ", ";
    }
    ss << "]\n";

    for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
      if (_shared_mem_ptrs[i] != nullptr) {
        munmap(_shared_mem_ptrs[i], compute_shm_size());
      }

      if (!_shm_names[i].empty()) {
        shm_unlink(_shm_names[i].c_str());
      }
    }
  }

  int _rank;
  int _group_size;
  int _thread_num;
  std::array<std::string, MAX_SHM_RANK_NUM> _shm_names;
  std::array<void*, MAX_SHM_RANK_NUM> _shared_mem_ptrs;
  ThreadSHMContext* _shm_ctx;
};

namespace shm_cc_ops {
template <typename scalar_t, typename F>
void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
  int thread_num = ctx->thread_num;
  int64_t total_bytes = elem_num * sizeof(scalar_t);
  int64_t total_units_num =
      (total_bytes + MIN_THREAD_PROCESS_SIZE - 1) / MIN_THREAD_PROCESS_SIZE;
  int64_t per_thread_units_num =
      (total_units_num + thread_num - 1) / thread_num;
  int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
  int64_t max_per_thread_iteration_elem_num =
      PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
  int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;

#pragma omp parallel for schedule(static, 1)
  for (int i = 0; i < thread_num; ++i) {
    int64_t offset = i * per_thread_elem_num;
    int64_t end = std::min(elem_num, offset + per_thread_elem_num);
    int64_t curr_elem_num =
        std::min(max_per_thread_iteration_elem_num, end - offset);
    ThreadSHMContext* thread_ctx = ctx + i;

    while (curr_elem_num > 0) {
      inner_func(thread_ctx, offset, curr_elem_num);

      offset += max_per_thread_iteration_elem_num;
      curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
    }
  }
}
};  // namespace shm_cc_ops
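
`shm_cc_loop` splits the element range into 8 KiB units, hands each OpenMP thread a contiguous run of units, and then walks that run in chunks that fit the 4 MiB per-thread shared-memory buffer. A worked example with illustrative numbers (not taken from the diff) may help; the arithmetic below simply traces the expressions above for fp32 data.

// Worked example (assumed inputs): scalar_t = float, elem_num = 12'000'000,
// thread_num = 12.
//   total_bytes          = 48'000'000
//   total_units_num      = ceil(48'000'000 / 8192) = 5860
//   per_thread_units_num = ceil(5860 / 12)         = 489
//   per_unit_elem_num    = 8192 / 4                = 2048
//   per_thread_elem_num  = 489 * 2048              = 1'001'472 elements
//   max elems per pass   = 4 MiB / 4               = 1'048'576
// Each thread therefore covers its ~1M-element slice in a single call to
// inner_func; larger inputs advance `offset` by 1'048'576 elements per pass.
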
namespace shm_cc_ops {

void memcpy_from_shm(void* dst, void* src, const int64_t bytes) {
  const int64_t aligned_bytes = ((bytes >> 6) << 6);  // 64 bytes aligned
  int64_t i = 0;
#pragma GCC unroll 4
  for (; i < aligned_bytes; i += 64) {
    vec_op::INT8Vec64 data(
        true, (int8_t*)src + i);  // stream loading shm to avoid caching
    data.save((int8_t*)dst + i);
  }
  if (aligned_bytes < bytes) {
    vec_op::INT8Vec64 data(true, (int8_t*)src + aligned_bytes);
    data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
  }
}

void memcpy_to_shm(void* dst, void* src, const int64_t bytes) {
#pragma GCC unroll 4
  for (int64_t i = 0; i < bytes; i += 64) {
    vec_op::INT8Vec64 data((int8_t*)src + i);
    data.nt_save((int8_t*)dst + i);
  }
}

void memcpy(void* dst, void* src, const int64_t bytes) {
  const int64_t aligned_bytes = ((bytes >> 6) << 6);  // 64 bytes aligned
  int64_t i = 0;
#pragma GCC unroll 4
  for (; i < aligned_bytes; i += 64) {
    vec_op::INT8Vec64 data((int8_t*)src + i);
    data.save((int8_t*)dst + i);
  }
  if (aligned_bytes < bytes) {
    vec_op::INT8Vec64 data((int8_t*)src + aligned_bytes);
    data.save((int8_t*)dst + aligned_bytes, bytes - aligned_bytes);
  }
}

template <typename scalar_t, int RANKS>
void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
                         size_t elem_num) {
  CPU_KERNEL_GUARD_IN(all_reduce_sum_impl)
  using vec_t = typename KernelVecType<scalar_t>::scalar_vec_t;
  constexpr int64_t vec_elem_num = vec_t::get_elem_num();
  const int worldsize = ctx->group_size;

  shm_cc_ops::shm_cc_loop<scalar_t>(
      ctx, elem_num,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        int rank = thread_ctx->rank;
        scalar_t* thread_shm_ptr =
            thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
        scalar_t* thread_data_ptr = data + data_offset;
        int64_t thread_data_elem_num = data_elem_num * sizeof(scalar_t);

        scalar_t* remote_data_ptrs[RANKS - 1];
        vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
          remote_data_ptrs[idx] = thread_ctx->get_thread_shm_ptr<scalar_t>(
              thread_ctx->get_swizzled_rank(idx + 1));
        });

        thread_ctx->barrier(ThreadSHMStat::THREAD_READY);

        shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
                                  thread_data_elem_num);

        thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);

        int64_t aligned_data_elem_num =
            (data_elem_num / vec_elem_num) * vec_elem_num;
        int64_t i = 0;
#pragma GCC unroll 4
        for (; i < aligned_data_elem_num; i += vec_elem_num) {
          vec_t local_data(thread_data_ptr + i);  // load from cache
          vec_op::FP32Vec16 local_data_fp32(local_data);
          vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
            vec_t remote_data(
                true, remote_data_ptrs[idx] + i);  // stream load from shm
            vec_op::FP32Vec16 remote_data_fp32(remote_data);
            local_data_fp32 = local_data_fp32 + remote_data_fp32;  // sum reduce
          });
          vec_t reduced_data(local_data_fp32);
          reduced_data.save(thread_data_ptr + i);
        }

        if (i < data_elem_num) {
          vec_t local_data(thread_data_ptr + i);  // load from cache
          vec_op::FP32Vec16 local_data_fp32(local_data);
          vec_op::unroll_loop<int, RANKS - 1>([&](int idx) {
            vec_t remote_data(
                true, remote_data_ptrs[idx] + i);  // stream load from shm
            vec_op::FP32Vec16 remote_data_fp32(remote_data);
            local_data_fp32 = local_data_fp32 + remote_data_fp32;  // sum reduce
          });
          vec_t reduced_data(local_data_fp32);
          reduced_data.save(thread_data_ptr + i,
                            data_elem_num - aligned_data_elem_num);
        }

        thread_ctx->barrier(ThreadSHMStat::DONE);
      });

  return;
}
};  // namespace shm_cc_ops

std::vector<std::unique_ptr<SHMManager>> SHMManager::SingletonInstances = {};
std::mutex SHMManager::SingletonInstancesLock = {};

template <typename scalar_t>
void shm_allreduce_sum(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num) {
  switch (ctx->group_size) {
    case 2:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 2>(ctx, data, elem_num);
      break;
    case 3:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 3>(ctx, data, elem_num);
      break;
    case 4:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 4>(ctx, data, elem_num);
      break;
    case 8:
      shm_cc_ops::all_reduce_sum_impl<scalar_t, 8>(ctx, data, elem_num);
      break;
    default:
      TORCH_CHECK(false,
                  "Invalid world size: " + std::to_string(ctx->group_size));
  }
}

template <typename scalar_t>
void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
                     scalar_t** outputs, const int dst) {
  CPU_KERNEL_GUARD_IN(shm_gather_impl)
  const int worldsize = ctx->group_size;
  TORCH_CHECK_LT(dst, worldsize);
  shm_cc_ops::shm_cc_loop<scalar_t>(
      ctx, elem_num,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        int rank = thread_ctx->rank;
        scalar_t* thread_shm_ptr =
            thread_ctx->get_thread_shm_ptr<scalar_t>(rank);

        thread_ctx->barrier(ThreadSHMStat::THREAD_READY);

        shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
                                  data_elem_num * sizeof(scalar_t));

        thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);

        if (rank == dst) {
          shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
                             data_elem_num * sizeof(scalar_t));
          for (int i = 1; i < worldsize; ++i) {
            int src_rank = thread_ctx->get_swizzled_rank(i);
            scalar_t* src_ptr =
                thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank);  // shm
            scalar_t* dst_ptr = outputs[src_rank] + data_offset;
            shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
                                        data_elem_num * sizeof(scalar_t));
          }
        }

        thread_ctx->barrier(ThreadSHMStat::DONE);
      });

  return;
}

struct MemPiece {
  void* ptr;
  int64_t size;

  template <typename T>
  T* data_ptr() {
    return reinterpret_cast<T*>(ptr);
  }
};

struct TensorListMeta {
  int64_t tensor_bytes[MAX_P2P_SEND_TENSOR_NUM];
  torch::ScalarType tensor_types[MAX_P2P_SEND_TENSOR_NUM];
  int64_t tensor_num;
  int64_t total_bytes;

  TensorListMeta() : tensor_num(0), total_bytes(0) {
    static_assert(sizeof(TensorListMeta) % 64 == 0);
    static_assert(sizeof(TensorListMeta) <
                  MIN_THREAD_PROCESS_SIZE);  // To ensure the metadata always
                                             // hold by the thread 0
    for (int i = 0; i < MAX_P2P_SEND_TENSOR_NUM; ++i) {
      tensor_bytes[i] = 0;
      tensor_ptrs[i] = nullptr;
      tensor_types[i] = torch::ScalarType::Undefined;
    }
  }

  // For send and recv
  void bind_tensor_list(std::vector<torch::Tensor>& tensor_list) {
    TORCH_CHECK(tensor_types[0] == torch::ScalarType::Undefined,
                "Re-bind TensorListMeta is not allowed.")
    TORCH_CHECK_LE(tensor_list.size(), MAX_P2P_SEND_TENSOR_NUM);
    tensor_num = tensor_list.size();
    int64_t bytes_sum = 0;
    for (int i = 0; i < tensor_list.size(); ++i) {
      torch::Tensor& t = tensor_list[i];
      TORCH_CHECK(t.is_contiguous());
      tensor_bytes[i] = t.nbytes();
      tensor_types[i] = t.scalar_type();
      tensor_ptrs[i] = t.data_ptr();
      bytes_sum += t.nbytes();
    }
    total_bytes = bytes_sum;
  }

  // For recv
  std::vector<torch::Tensor> generate_tensor_list() {
    std::vector<torch::Tensor> tensor_list;
    tensor_list.reserve(tensor_num);

    for (int i = 0; i < tensor_num; ++i) {
      int64_t bytes = tensor_bytes[i];
      auto type = tensor_types[i];
      int64_t elem_bytes = torch::elementSize(type);

      TORCH_CHECK_EQ(bytes % elem_bytes, 0);
      int64_t elem_num = bytes / elem_bytes;
      auto options = torch::TensorOptions().dtype(type).device(torch::kCPU);
      tensor_list.emplace_back(torch::empty({elem_num}, options));
    }
    return tensor_list;
  }

  MemPiece get_data(int64_t offset) {
    for (int i = 0; i < tensor_num; ++i) {
      if (offset < tensor_bytes[i]) {
        return {reinterpret_cast<int8_t*>(tensor_ptrs[i]) + offset,
                tensor_bytes[i] - offset};
      }
      offset -= tensor_bytes[i];
    }
    return {nullptr, 0};
  }

 private:
  void* tensor_ptrs[MAX_P2P_SEND_TENSOR_NUM];
  int8_t _padding[40];
};

void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
                               const std::vector<torch::Tensor>& tensor_list) {
  CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
  std::vector<torch::Tensor> tensor_list_with_metadata;
  tensor_list_with_metadata.reserve(1 + tensor_list.size());

  auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
  tensor_list_with_metadata.emplace_back(
      torch::empty({sizeof(TensorListMeta)}, options));
  tensor_list_with_metadata.insert(tensor_list_with_metadata.end(),
                                   tensor_list.begin(), tensor_list.end());

  torch::Tensor& metadata_tensor = tensor_list_with_metadata[0];
  TORCH_CHECK_EQ(metadata_tensor.nbytes(), sizeof(TensorListMeta));

  TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
  metadata->bind_tensor_list(tensor_list_with_metadata);

  shm_cc_ops::shm_cc_loop<int8_t>(
      ctx, metadata->total_bytes,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        int rank = thread_ctx->rank;
        // Wait until the receiver set the stat to DONE
        thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);

        int64_t curr_shm_offset = 0;
        while (curr_shm_offset < data_elem_num) {
          MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
          frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
          shm_cc_ops::memcpy(
              thread_ctx->get_thread_shm_ptr<int8_t>(rank) + curr_shm_offset,
              frag.ptr, frag.size);
          curr_shm_offset += frag.size;
        }

        thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
      });
}

std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
                                                     int64_t src) {
  CPU_KERNEL_GUARD_IN(shm_recv_tensor_list_impl)
  auto options = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCPU);
  torch::Tensor metadata_tensor =
      torch::empty({sizeof(TensorListMeta)}, options);

  // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
  ctx->wait_for_one(src, ThreadSHMStat::DONE);
  shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
                     ctx->get_thread_shm_ptr<void>(src),
                     sizeof(TensorListMeta));
  TensorListMeta* src_metadata =
      reinterpret_cast<TensorListMeta*>(metadata_tensor.data_ptr());
  std::vector<torch::Tensor> tensor_list_with_metadata =
      src_metadata->generate_tensor_list();

  TensorListMeta metadata;
  metadata.bind_tensor_list(tensor_list_with_metadata);
  TORCH_CHECK_EQ(metadata.tensor_num, src_metadata->tensor_num);
  TORCH_CHECK_EQ(metadata.total_bytes, src_metadata->total_bytes);

  shm_cc_ops::shm_cc_loop<int8_t>(
      ctx, metadata.total_bytes,
      [&](ThreadSHMContext* thread_ctx, int64_t data_offset,
          int64_t data_elem_num) {
        // Wait until the sender set the stat to SHM_DATA_READY
        thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
        int64_t curr_shm_offset = 0;
        while (curr_shm_offset < data_elem_num) {
          MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
          frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
          shm_cc_ops::memcpy(
              frag.ptr,
              thread_ctx->get_thread_shm_ptr<int8_t>(src) + curr_shm_offset,
              frag.size);
          curr_shm_offset += frag.size;
        }

        thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
      });

  std::vector<torch::Tensor> tensor_list;
  tensor_list.reserve(metadata.tensor_num - 1);
  tensor_list.insert(tensor_list.begin(), tensor_list_with_metadata.begin() + 1,
                     tensor_list_with_metadata.end());

  return tensor_list;
}
}  // namespace

void shm_gather(int64_t handle, torch::Tensor& data,
                const std::optional<std::vector<torch::Tensor>>& outputs,
                int64_t dst) {
  TORCH_CHECK(data.is_contiguous())
  VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_gather_impl", [&] {
    CPU_KERNEL_GUARD_IN(shm_gather_impl)

    if (outputs.has_value()) {
      TORCH_CHECK_LE(outputs->size(), MAX_SHM_RANK_NUM);
      scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
      for (int i = 0; i < outputs->size(); ++i) {
        output_ptrs[i] = outputs->at(i).data_ptr<scalar_t>();
      }
      shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
                      data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
                      dst);
    } else {
      shm_gather_impl(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
                      data.data_ptr<scalar_t>(), data.numel(), (scalar_t**)(0),
                      dst);
    }

    CPU_KERNEL_GUARD_OUT(shm_gather_impl)
  });
}

void shm_all_gather(int64_t handle, const torch::Tensor& data,
                    torch::Tensor& output) {
  TORCH_CHECK(data.is_contiguous())
  TORCH_CHECK(output.is_contiguous())

  const int64_t input_elem_num = data.numel();
  const int64_t output_elem_num = output.numel();
  TORCH_CHECK_EQ(output_elem_num % input_elem_num, 0);
  const int world_size = output_elem_num / input_elem_num;

  VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_all_gather_impl", [&] {
    CPU_KERNEL_GUARD_IN(shm_all_gather_impl)
    auto ctx = SHMManager::get_singleton_instance(handle)->get_shm_ctx();
    TORCH_CHECK_EQ(ctx->group_size, world_size);

    scalar_t* output_ptrs[MAX_SHM_RANK_NUM] = {nullptr};
    for (int i = 0; i < world_size; ++i) {
      output_ptrs[i] = output.data_ptr<scalar_t>() + i * input_elem_num;
    }
    shm_gather_impl(ctx, data.data_ptr<scalar_t>(), data.numel(), output_ptrs,
                    ctx->rank);
    CPU_KERNEL_GUARD_OUT(shm_all_gather_impl)
  });
}

void shm_allreduce(int64_t handle, torch::Tensor& data) {
  TORCH_CHECK(data.is_contiguous())
  VLLM_DISPATCH_FLOATING_TYPES(data.scalar_type(), "shm_allreduce_sum", [&] {
    CPU_KERNEL_GUARD_IN(shm_allreduce_sum)
    shm_allreduce_sum(SHMManager::get_singleton_instance(handle)->get_shm_ctx(),
                      data.data_ptr<scalar_t>(), data.numel());
    CPU_KERNEL_GUARD_OUT(shm_allreduce_sum)
  });
}

void shm_send_tensor_list(int64_t handle,
                          const std::vector<torch::Tensor>& tensor_list,
                          int64_t dst) {
  CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
  shm_send_tensor_list_impl(
      SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
  CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
}

std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
  CPU_KERNEL_GUARD_IN(shm_recv_tensor_list)
  auto tensor_list = shm_recv_tensor_list_impl(
      SHMManager::get_singleton_instance(handle)->get_shm_ctx(), src);
  CPU_KERNEL_GUARD_OUT(shm_recv_tensor_list)
  return tensor_list;
}

int64_t init_shm_manager(const std::string& name, const int64_t group_size,
                         const int64_t rank) {
  return SHMManager::create_singleton_instance(name, group_size, rank);
}

std::string join_shm_manager(int64_t handle, const std::string& name) {
  auto shm_manager = SHMManager::get_singleton_instance(handle);
  TORCH_CHECK(shm_manager);
  shm_manager->join(name);
  return shm_manager->get_shm_ctx()->to_string();
}
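
Taken together, the exported entry points form a small lifecycle: every rank creates its own segment, then attaches to the peers' segments, after which the collectives can be issued on contiguous CPU tensors. A hedged C++ sketch of that call order (the group name and shape are placeholders; in vLLM these functions are normally reached through the torch op registrations shown below rather than called directly):

// Illustrative call order for one rank of a 4-rank CPU group.
void shm_usage_sketch(int rank) {
  const int group_size = 4;
  // 1. Every rank creates and maps its own shared-memory segment.
  int64_t handle = init_shm_manager("vllm_shm_demo", group_size, rank);
  // 2. After all ranks have finished step 1 (external synchronization
  //    assumed), each rank maps the peers' segments.
  join_shm_manager(handle, "vllm_shm_demo");
  // 3. Collectives operate in place.
  auto t = torch::ones({1024, 1024}, torch::kFloat32);
  shm_allreduce(handle, t);  // t now holds the element-wise sum over ranks
}
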
@@ -18,6 +18,30 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
                         const std::optional<torch::Tensor>& azp,
                         const std::optional<torch::Tensor>& bias);
 
+void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
+                        torch::Tensor& kv_cache, double scale,
+                        torch::Tensor& block_tables, torch::Tensor& seq_lens);
+
+int64_t init_shm_manager(const std::string& name, const int64_t group_size,
+                         const int64_t rank);
+
+std::string join_shm_manager(int64_t handle, const std::string& name);
+
+void shm_allreduce(int64_t handle, torch::Tensor& data);
+
+void shm_gather(int64_t handle, torch::Tensor& data,
+                const std::optional<std::vector<torch::Tensor>>& outputs,
+                int64_t dst);
+
+void shm_all_gather(int64_t handle, const torch::Tensor& data,
+                    torch::Tensor& output);
+
+void shm_send_tensor_list(int64_t handle,
+                          const std::vector<torch::Tensor>& tensor_list,
+                          int64_t dst);
+
+std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
@@ -127,6 +151,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " Tensor? azp, Tensor? bias) -> ()");
   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
 #endif
+
+  // SHM CCL
+#ifdef __AVX512F__
+  ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
+          &init_shm_manager);
+  ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
+  ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
+  ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
+  ops.def(
+      "shm_gather(int handle, Tensor data, Tensor[](a!)? outputs, int dst) -> "
+      "()");
+  ops.impl("shm_gather", torch::kCPU, &shm_gather);
+  ops.def(
+      "shm_all_gather(int handle, Tensor data, Tensor! output) -> "
+      "()");
+  ops.impl("shm_all_gather", torch::kCPU, &shm_all_gather);
+  ops.def(
+      "shm_send_tensor_list(int handle, Tensor[](a) tensor_list, int dst) -> "
+      "()");
+  ops.impl("shm_send_tensor_list", torch::kCPU, &shm_send_tensor_list);
+  ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
+          &shm_recv_tensor_list);
+#endif
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
@@ -150,6 +197,14 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " str kv_cache_dtype,"
      " Tensor k_scale, Tensor v_scale) -> ()");
   cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
+
+  cache_ops.def(
+      "concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
+      " Tensor! kv_cache,"
+      " Tensor slot_mapping,"
+      " str kv_cache_dtype,"
+      " Tensor scale) -> ()");
+  cache_ops.impl("concat_and_cache_mla", torch::kCPU, &concat_and_cache_mla);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
@@ -157,4 +212,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
   utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
 }
+
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) {
+  cpu_ops.def(
+      "mla_decode_kvcache("
+      " Tensor! out, Tensor query, Tensor kv_cache,"
+      " float scale, Tensor block_tables, Tensor seq_lens) -> ()");
+  cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache);
+}
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
@@ -4,6 +4,11 @@
 #include <string>
 #include <sched.h>
 #endif
+#if __GLIBC__ == 2 && __GLIBC_MINOR__ < 30
+#include <unistd.h>
+#include <sys/syscall.h>
+#define gettid() syscall(SYS_gettid)
+#endif
 
 #include "cpu_types.hpp"
 
@@ -18,7 +23,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
 
 #ifndef VLLM_NUMA_DISABLED
 std::string init_cpu_threads_env(const std::string& cpu_ids) {
-  bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
+  bitmask* omp_cpu_mask = numa_parse_cpustring_all(cpu_ids.c_str());
   TORCH_CHECK(omp_cpu_mask->size > 0);
   std::vector<int> omp_cpu_ids;
   omp_cpu_ids.reserve(omp_cpu_mask->size);
39
csrc/cuda_view.cu
Normal file
@@ -0,0 +1,39 @@
#include <torch/all.h>
#include <torch/cuda.h>
#include <cuda_runtime.h>

// This function assumes that `cpu_tensor` is a CPU tensor allocated with pinned
// memory, and that UVA (Unified Virtual Addressing) is enabled.
torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor) {
  TORCH_CHECK(cpu_tensor.device().is_cpu(), "Input tensor must be on CPU");

  // Get raw host pointer from CPU tensor
  void* host_ptr = cpu_tensor.data_ptr();

  // Get a device pointer corresponding to the pinned host memory
  void* device_ptr = nullptr;
  cudaError_t err = cudaHostGetDevicePointer(&device_ptr, host_ptr, 0);
  TORCH_CHECK(err == cudaSuccess,
              "cudaHostGetDevicePointer failed: ", cudaGetErrorString(err));

  // We'll use the same sizes, strides, and dtype as the CPU tensor.
  // TODO: check if layout is respected.
  auto sizes = cpu_tensor.sizes();
  auto strides = cpu_tensor.strides();
  auto options = cpu_tensor.options().device(torch::kCUDA);

  // from_blob signature: from_blob(void *data, IntArrayRef sizes, ..., Deleter,
  // const TensorOptions &) Provide a no-op deleter. The CPU tensor holds the
  // memory, so we don't free it here.
  auto deleter = [](void*) {
    // no-op, since the memory is owned by the original CPU tensor
  };

  torch::Tensor cuda_tensor =
      torch::from_blob(device_ptr, sizes, strides, deleter, options);

  TORCH_CHECK(cuda_tensor.device().is_cuda(),
              "Resulting tensor is not on CUDA device");

  return cuda_tensor;
}
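
As the comment in the new file states, the helper requires the CPU tensor to be pinned so that `cudaHostGetDevicePointer` can resolve a device alias for the same memory. A minimal usage sketch under that assumption (shape and dtype are arbitrary):

// Sketch only: assumes a CUDA build with UVA available.
#include <torch/torch.h>

void cuda_view_sketch() {
  auto cpu_t = torch::zeros(
      {1 << 20},
      torch::TensorOptions().dtype(torch::kFloat32).pinned_memory(true));
  // Zero-copy device view over the pinned host allocation.
  torch::Tensor gpu_view = get_cuda_view_from_cpu_tensor(cpu_t);
  // Kernels that read/write gpu_view touch the same bytes the CPU sees; the
  // CPU tensor keeps ownership, so keep it alive while the view is in use.
}
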
@@ -12,7 +12,7 @@ static_assert(sizeof(void*) == sizeof(fptr_t));
 
 fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,
-                      bool full_nvlink) {
+                      bool fully_connected) {
   int world_size = fake_ipc_ptrs.size();
   if (world_size > 8)
     throw std::invalid_argument("world size > 8 is not supported");
@@ -27,7 +27,7 @@ fptr_t init_custom_ar(const std::vector<fptr_t>& fake_ipc_ptrs,
   }
   return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
                                             rank_data.numel(), rank, world_size,
-                                            full_nvlink);
+                                            fully_connected);
 }
 
 /**
@@ -142,3 +142,48 @@ void register_graph_buffers(fptr_t _fa,
   bytes.reserve(handles.size());
   fa->register_graph_buffers(bytes, offsets);
 }
+
+std::tuple<fptr_t, torch::Tensor> allocate_shared_buffer_and_handle(
+    int64_t size) {
+  auto device_index = c10::cuda::current_device();
+  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
+  void* buffer;
+  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Allocate buffer
+#if defined(USE_ROCM)
+  // data buffers need to be "uncached" for signal on MI200
+  AT_CUDA_CHECK(
+      hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
+#else
+  AT_CUDA_CHECK(cudaMalloc((void**)&buffer, size));
+#endif
+  AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
+  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+
+  // Create IPC memhandle for the allocated buffer.
+  // Will use it in open_mem_handle.
+  auto options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
+  auto handle =
+      torch::empty({static_cast<int64_t>(sizeof(cudaIpcMemHandle_t))}, options);
+  AT_CUDA_CHECK(
+      cudaIpcGetMemHandle((cudaIpcMemHandle_t*)handle.data_ptr(), buffer));
+
+  return std::make_tuple(reinterpret_cast<fptr_t>(buffer), handle);
+}
+
+fptr_t open_mem_handle(torch::Tensor& mem_handle) {
+  void* ipc_ptr;
+  AT_CUDA_CHECK(cudaIpcOpenMemHandle(
+      (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)mem_handle.data_ptr()),
+      cudaIpcMemLazyEnablePeerAccess));
+  return reinterpret_cast<fptr_t>(ipc_ptr);
+}
+
+void free_shared_buffer(fptr_t buffer) {
+  AT_CUDA_CHECK(cudaFree(reinterpret_cast<void*>(buffer)));
+}
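
The three new helpers are meant to be used together: one process allocates a buffer and publishes its IPC handle, peer processes open the handle, and the allocating process eventually frees the buffer. A hedged sketch of that flow (how the handle tensor is shipped between processes is out of scope here, and error handling is omitted):

// Illustrative IPC flow across two cooperating processes.
void ipc_flow_sketch() {
  // Producer process:
  auto [buffer_ptr, handle] = allocate_shared_buffer_and_handle(1 << 20);
  // ... send `handle` (a CPU uint8 tensor of sizeof(cudaIpcMemHandle_t)
  //     bytes) to the peer processes by any transport ...

  // Consumer process, after receiving `handle`:
  fptr_t peer_ptr = open_mem_handle(handle);
  // ... use peer_ptr as a device pointer into the producer's allocation ...

  // Producer, once all consumers are done with the buffer:
  free_shared_buffer(buffer_ptr);
}
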
|
|||||||
@ -5,6 +5,10 @@
|
|||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#if defined(USE_ROCM)
|
||||||
|
typedef __hip_bfloat16 nv_bfloat16;
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
@ -12,6 +16,7 @@
|
|||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
namespace vllm {
|
||||||
#define CUDACHECK(cmd) \
|
#define CUDACHECK(cmd) \
|
||||||
do { \
|
do { \
|
||||||
cudaError_t e = cmd; \
|
cudaError_t e = cmd; \
|
||||||
@ -22,24 +27,37 @@
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
namespace vllm {
|
// Maximal number of blocks in allreduce kernel.
|
||||||
|
|
||||||
constexpr int kMaxBlocks = 36;
|
constexpr int kMaxBlocks = 36;
|
||||||
|
|
||||||
|
// Default number of blocks in allreduce kernel.
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
const int defaultBlockLimit = 36;
|
||||||
|
CUpointer_attribute rangeStartAddrAttr = CU_POINTER_ATTRIBUTE_RANGE_START_ADDR;
|
||||||
|
#else
|
||||||
|
const int defaultBlockLimit = 16;
|
||||||
|
hipPointer_attribute rangeStartAddrAttr =
|
||||||
|
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Counter may overflow, but it's fine since unsigned int overflow is
|
// Counter may overflow, but it's fine since unsigned int overflow is
|
||||||
// well-defined behavior.
|
// well-defined behavior.
|
||||||
using FlagType = uint32_t;
|
using FlagType = uint32_t;
|
||||||
|
|
||||||
|
// Two sets of peer counters are needed for two syncs: starting and ending an
|
||||||
|
// operation. The reason is that it's possible for peer GPU block to arrive at
|
||||||
|
// the second sync point while the current GPU block haven't passed the first
|
||||||
|
// sync point. Thus, peer GPU may write counter+1 while current GPU is busy
|
||||||
|
// waiting for counter. We use alternating counter array to avoid this
|
||||||
|
// possibility.
|
||||||
struct Signal {
|
struct Signal {
|
||||||
alignas(128) FlagType self_counter[kMaxBlocks][8];
|
alignas(128) FlagType start[kMaxBlocks][8];
|
||||||
// Two sets of peer counters are needed for two syncs. The reason is that
|
alignas(128) FlagType end[kMaxBlocks][8];
|
||||||
// it's possible for peer GPU block to arrive at the second sync point while
|
alignas(128) FlagType _flag[kMaxBlocks]; // incremental flags for each rank
|
||||||
// the current GPU block haven't passed the first sync point. Thus, peer GPU
|
|
||||||
// may write counter+1 while current GPU is busy waiting for counter. We use
|
|
||||||
// alternating counter array to avoid this possibility.
|
|
||||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct __align__(16) RankData {
|
struct __align__(16) RankData {
|
||||||
const void* __restrict__ ptrs[8];
|
const void* ptrs[8];
|
||||||
};
|
};
|
||||||
|
|
||||||
struct __align__(16) RankSignals {
|
struct __align__(16) RankSignals {
|
||||||
@ -134,27 +152,29 @@ DINLINE O downcast(array_t<float, O::size> val) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
|
|
||||||
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||||
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag),
|
||||||
"l"(flag_addr));
|
"l"(flag_addr));
|
||||||
#else
|
#else
|
||||||
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
|
asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag),
|
||||||
"l"(flag_addr));
|
"l"(flag_addr));
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) {
|
||||||
FlagType flag;
|
FlagType flag;
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||||
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
|
||||||
: "=r"(flag)
|
: "=r"(flag)
|
||||||
: "l"(flag_addr));
|
: "l"(flag_addr));
|
||||||
#else
|
#else
|
||||||
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
|
asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;"
|
||||||
: "=r"(flag)
|
: "=r"(flag)
|
||||||
: "l"(flag_addr));
|
: "l"(flag_addr));
|
||||||
#endif
|
#endif
|
||||||
return flag;
|
return flag;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,37 +190,99 @@ static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) {
|
|||||||
return flag;
|
return flag;
|
||||||
}
|
}
|
||||||
|
|
||||||
// is_start: whether this is the very first synchronization barrier.
|
// This function is meant to be used as the first synchronization in the all
|
||||||
// need_fence: whether a memory fence is needed. If true, a release-acquire
|
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||||
// semantic is used to enforce memory access order before and after this
|
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||||
// barrier.
|
// other volatile writes.
|
||||||
template <int ngpus, bool is_start, bool need_fence = false>
|
template <int ngpus>
|
||||||
DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
|
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
|
||||||
int rank) {
|
int rank) {
|
||||||
if constexpr (!is_start) __syncthreads();
|
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||||
static_assert(
|
|
||||||
!(is_start && need_fence)); // Start barrier shouldn't need fence.
|
|
||||||
if (threadIdx.x < ngpus) {
|
if (threadIdx.x < ngpus) {
|
||||||
// Increment the counter. Technically we only need one counter, but we use
|
auto peer_counter_ptr = &sg.signals[threadIdx.x]->start[blockIdx.x][rank];
|
||||||
// multiple per block to eliminate the need to share the counter via smem.
|
auto self_counter_ptr = &self_sg->start[blockIdx.x][threadIdx.x];
|
||||||
auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1;
|
// Write the expected counter value to peer and wait for correct value
|
||||||
|
// from peer.
|
||||||
|
st_flag_volatile(peer_counter_ptr, flag);
|
||||||
|
while (ld_flag_volatile(self_counter_ptr) != flag);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
// use one thread to update flag
|
||||||
|
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function is meant to be used as the second or the final
|
||||||
|
// synchronization barrier in the all reduce kernel. If it's the final
|
||||||
|
// synchronization barrier, we don't need to make any visibility guarantees
|
||||||
|
// for prior memory accesses.
|
||||||
|
template <int ngpus, bool final_sync = false>
|
||||||
|
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||||
|
__syncthreads();
|
||||||
|
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||||
|
if (threadIdx.x < ngpus) {
|
||||||
|
auto peer_counter_ptr = &sg.signals[threadIdx.x]->end[blockIdx.x][rank];
|
||||||
|
auto self_counter_ptr = &self_sg->end[blockIdx.x][threadIdx.x];
|
||||||
// Write the expected counter value to peer and wait for correct value from
|
// Write the expected counter value to peer and wait for correct value from
|
||||||
// peer.
|
// peer.
|
||||||
auto peer_counter_ptr =
|
if constexpr (!final_sync) {
|
||||||
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
|
st_flag_release(peer_counter_ptr, flag);
|
||||||
auto self_counter_ptr =
|
while (ld_flag_acquire(self_counter_ptr) != flag);
|
||||||
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
|
|
||||||
if constexpr (need_fence) {
|
|
||||||
st_flag_release(peer_counter_ptr, val);
|
|
||||||
while (ld_flag_acquire(self_counter_ptr) != val);
|
|
||||||
} else {
|
} else {
|
||||||
st_flag_volatile(peer_counter_ptr, val);
|
st_flag_volatile(peer_counter_ptr, flag);
|
||||||
while (ld_flag_volatile(self_counter_ptr) != val);
|
while (ld_flag_volatile(self_counter_ptr) != flag);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if constexpr (is_start || need_fence) __syncthreads();
|
if constexpr (!final_sync) __syncthreads();
|
||||||
|
|
||||||
|
// use one thread to update flag
|
||||||
|
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
template <int ngpus>
|
||||||
|
DINLINE void barrier_at_start(const RankSignals& sg, Signal* self_sg,
|
||||||
|
int rank) {
|
||||||
|
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||||
|
if (threadIdx.x < ngpus) {
|
||||||
|
// simultaneously write to the corresponding flag of all ranks.
|
||||||
|
// Latency = 1 p2p write
|
||||||
|
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank],
|
||||||
|
flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
||||||
|
// wait until we got true from all ranks
|
||||||
|
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
|
||||||
|
__ATOMIC_RELAXED,
|
||||||
|
__MEMORY_SCOPE_DEVICE) < flag);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
// use one thread to update flag
|
||||||
|
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int ngpus, bool final_sync = false>
|
||||||
|
DINLINE void barrier_at_end(const RankSignals& sg, Signal* self_sg, int rank) {
|
||||||
|
__syncthreads();
|
||||||
|
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
|
||||||
|
if (threadIdx.x < ngpus) {
|
||||||
|
// simultaneously write to the corresponding flag of all ranks.
|
||||||
|
// Latency = 1 p2p write
|
||||||
|
__scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||||
|
flag,
|
||||||
|
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||||
|
__MEMORY_SCOPE_SYSTEM);
|
||||||
|
// wait until we got true from all ranks
|
||||||
|
while (
|
||||||
|
__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
|
||||||
|
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||||
|
__MEMORY_SCOPE_DEVICE) < flag);
|
||||||
|
}
|
||||||
|
if constexpr (!final_sync) __syncthreads();
|
||||||
|
// use one thread to update flag
|
||||||
|
if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename P, int ngpus, typename A>
|
template <typename P, int ngpus, typename A>
|
||||||
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
DINLINE P packed_reduce(const P* ptrs[], int idx) {
|
||||||
A tmp = upcast(ptrs[0][idx]);
|
A tmp = upcast(ptrs[0][idx]);
|
||||||
@ -220,13 +302,13 @@ __global__ void __launch_bounds__(512, 1)
|
|||||||
// note: we don't reorder the address so the accumulation order is the same
|
// note: we don't reorder the address so the accumulation order is the same
|
||||||
// for all ranks, ensuring bitwise identical results
|
// for all ranks, ensuring bitwise identical results
|
||||||
auto dp = *_dp;
|
auto dp = *_dp;
|
||||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
barrier_at_start<ngpus>(sg, self_sg, rank);
|
||||||
// do the actual reduction
|
// do the actual reduction
|
||||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||||
idx += gridDim.x * blockDim.x) {
|
idx += gridDim.x * blockDim.x) {
|
||||||
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
|
||||||
}
|
}
|
||||||
multi_gpu_barrier<ngpus, false>(sg, self_sg, rank);
|
barrier_at_end<ngpus, true>(sg, self_sg, rank);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename P>
|
template <typename P>
|
||||||
@ -255,18 +337,20 @@ __global__ void __launch_bounds__(512, 1)
|
|||||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||||
}
|
}
|
||||||
auto tmp_out = tmps[0];
|
auto tmp_out = tmps[0];
|
||||||
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
|
barrier_at_start<ngpus>(sg, self_sg, rank);
|
||||||
|
|
||||||
// stage 1: reduce scatter
|
// stage 1: reduce scatter
|
||||||
for (int idx = start + tid; idx < end; idx += stride) {
|
for (int idx = start + tid; idx < end; idx += stride) {
|
||||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||||
}
|
}
|
||||||
multi_gpu_barrier<ngpus, false, true>(sg, self_sg, rank);
|
barrier_at_end<ngpus>(sg, self_sg, rank);
|
||||||
|
|
||||||
// stage 2: allgather. Note: it's important to match the tid between
|
// stage 2: allgather. Note: it's important to match the tid between
|
||||||
// the two stages, because visibility across devices is only guaranteed
|
// the two stages, because visibility across devices is only guaranteed
|
||||||
// between threads that have the same tid. If thread i computes the sum of
|
// between threads that have the same tid. If thread i computes the sum of
|
||||||
// start + i in the first stage, then thread i also gathers start + i from all
|
// start + i in the first stage, then thread i also gathers start + i from
|
||||||
// ranks.
|
// all ranks.
|
||||||
|
|
||||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < ngpus; i++) {
|
for (int i = 0; i < ngpus; i++) {
|
||||||
@ -287,21 +371,22 @@ class CustomAllreduce {
  public:
   int rank_;
   int world_size_;
-  bool full_nvlink_;
+  // Full NVLink or xGMI connection between GPUs.
+  bool fully_connected_;

   RankSignals sg_;
-  // Stores an map from a pointer to its peer pointters from all ranks.
+  // Stores a map from a pointer to its peer pointers from all ranks.
   std::unordered_map<void*, RankData*> buffers_;
   Signal* self_sg_;

   // Stores rank data from all ranks. This is mainly for cuda graph purposes.
   // For cuda graph to work, all kernel arguments must be fixed during graph
-  // capture time. However, the peer pointers are not known during graph capture
-  // time. Therefore, during capture, we increment the rank data pointer and use
-  // that as the argument to the kernel. The kernel arguments are stored in
-  // graph_unreg_buffers_. The actual peer pointers will be filled in at the
-  // memory pointed to by the pointers in graph_unreg_buffers_ when
-  // the IPC handles are exchanged between ranks.
+  // capture time. However, the peer pointers are not known during graph
+  // capture time. Therefore, during capture, we increment the rank data
+  // pointer and use that as the argument to the kernel. The kernel arguments
+  // are stored in graph_unreg_buffers_. The actual peer pointers will be
+  // filled in at the memory pointed to by the pointers in
+  // graph_unreg_buffers_ when the IPC handles are exchanged between ranks.
   //
   // The overall process looks like this:
   // 1. Graph capture.
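The comment block above is the key to how this class cooperates with CUDA graphs. A condensed host-side sketch of the same idea, with simplified placeholder types (RankDataSketch and GraphRegSketch are illustrative names, not the real members):

#include <unordered_map>
#include <vector>

struct RankDataSketch { void* ptrs[8]; };

struct GraphRegSketch {
  bool capturing = false;
  RankDataSketch* d_rank_data_base = nullptr;  // preallocated device array
  std::vector<void*> graph_unreg_buffers;      // inputs seen during capture
  std::unordered_map<void*, RankDataSketch*> buffers;  // registered inputs

  // Return the kernel argument for `input`. During capture, hand out the next
  // fixed slot; its contents are filled in after IPC handles are exchanged.
  RankDataSketch* get_arg(void* input) {
    if (capturing) {
      RankDataSketch* slot = d_rank_data_base + graph_unreg_buffers.size();
      graph_unreg_buffers.push_back(input);
      return slot;
    }
    return buffers.at(input);
  }
};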
@ -319,17 +404,18 @@ class CustomAllreduce {
|
|||||||
* Signals are an array of ipc-enabled buffers from all ranks.
|
* Signals are an array of ipc-enabled buffers from all ranks.
|
||||||
* For each of the buffer, the layout is as follows:
|
* For each of the buffer, the layout is as follows:
|
||||||
* | -- sizeof(Signal) -- | ------ a few MB ----- |
|
* | -- sizeof(Signal) -- | ------ a few MB ----- |
|
||||||
* The first section is for allreduce synchronization, and the second section
|
* The first section is for allreduce synchronization, and the second
|
||||||
* is for storing the intermediate results required by some allreduce algos.
|
* section is for storing the intermediate results required by some
|
||||||
|
* allreduce algos.
|
||||||
*
|
*
|
||||||
* Note: this class does not own any device memory. Any required buffers
|
* Note: this class does not own any device memory. Any required buffers
|
||||||
* are passed in from the constructor.
|
* are passed in from the constructor.
|
||||||
*/
|
*/
|
||||||
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
|
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
|
||||||
int rank, int world_size, bool full_nvlink = true)
|
int rank, int world_size, bool fully_connected = true)
|
||||||
: rank_(rank),
|
: rank_(rank),
|
||||||
world_size_(world_size),
|
world_size_(world_size),
|
||||||
full_nvlink_(full_nvlink),
|
fully_connected_(fully_connected),
|
||||||
self_sg_(signals[rank]),
|
self_sg_(signals[rank]),
|
||||||
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
|
||||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||||
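The buffer layout described in the comment above (sizeof(Signal) bytes of synchronization state followed by a few MB of scratch) is what the temporary-buffer lookup relies on. A sketch of that address arithmetic with a placeholder Signal definition; the real field layout is declared earlier in this header and differs:

#include <cstdint>

struct SignalSketch {  // placeholder; the real Signal holds start/end flag arrays
  alignas(128) uint32_t flags[2][64][8];
};

// Scratch space for the 2-stage algorithm starts right after the Signal header.
template <typename P>
__device__ P* tmp_buf_sketch(SignalSketch* sg) {
  return reinterpret_cast<P*>(reinterpret_cast<uint8_t*>(sg) + sizeof(SignalSketch));
}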
@ -361,8 +447,7 @@ class CustomAllreduce {
|
|||||||
void* base_ptr;
|
void* base_ptr;
|
||||||
// note: must share the base address of each allocation, or we get wrong
|
// note: must share the base address of each allocation, or we get wrong
|
||||||
// address
|
// address
|
||||||
if (cuPointerGetAttribute(&base_ptr,
|
if (cuPointerGetAttribute(&base_ptr, rangeStartAddrAttr,
|
||||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
|
||||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||||
throw std::runtime_error("failed to get pointer attr");
|
throw std::runtime_error("failed to get pointer attr");
|
||||||
CUDACHECK(cudaIpcGetMemHandle(
|
CUDACHECK(cudaIpcGetMemHandle(
|
||||||
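The hunk above only moves the attribute constant into a named variable (presumably for ROCm/HIP portability); the underlying requirement stays the same: cudaIpcGetMemHandle must be called on the base address of the allocation, and the offset of the registered pointer within that allocation has to be carried separately. A self-contained sketch of that lookup using the same CUDA driver/runtime calls:

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdexcept>

// Get the IPC handle for the allocation containing `ptr`, plus the offset of
// `ptr` inside that allocation (handles are per-allocation, not per-pointer).
inline cudaIpcMemHandle_t ipc_handle_for(void* ptr, size_t* offset_out) {
  void* base_ptr = nullptr;
  if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                            (CUdeviceptr)ptr) != CUDA_SUCCESS)
    throw std::runtime_error("failed to get pointer attr");
  cudaIpcMemHandle_t handle;
  if (cudaIpcGetMemHandle(&handle, base_ptr) != cudaSuccess)
    throw std::runtime_error("cudaIpcGetMemHandle failed");
  *offset_out = static_cast<char*>(ptr) - static_cast<char*>(base_ptr);
  return handle;
}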
@ -396,11 +481,11 @@ class CustomAllreduce {
|
|||||||
|
|
||||||
// Note: when registering graph buffers, we intentionally choose to not
|
// Note: when registering graph buffers, we intentionally choose to not
|
||||||
// deduplicate the addresses. That means if the allocator reuses some
|
// deduplicate the addresses. That means if the allocator reuses some
|
||||||
// addresses, they will be registered again. This is to account for the remote
|
// addresses, they will be registered again. This is to account for the
|
||||||
// possibility of different allocation patterns between ranks. For example,
|
// remote possibility of different allocation patterns between ranks. For
|
||||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
// example, rank 1 may get the same input address for the second allreduce,
|
||||||
// got a different address. IPC handles have internal reference counting
|
// but rank 2 got a different address. IPC handles have internal reference
|
||||||
// mechanism so overhead should be small.
|
// counting mechanism so overhead should be small.
|
||||||
void register_graph_buffers(
|
void register_graph_buffers(
|
||||||
const std::vector<std::string>& handles,
|
const std::vector<std::string>& handles,
|
||||||
const std::vector<std::vector<int64_t>>& offsets) {
|
const std::vector<std::vector<int64_t>>& offsets) {
|
||||||
@ -431,15 +516,15 @@ class CustomAllreduce {
|
|||||||
/**
|
/**
|
||||||
* Performs allreduce, assuming input has already been registered.
|
* Performs allreduce, assuming input has already been registered.
|
||||||
*
|
*
|
||||||
* Block and grid default configs are results after careful grid search. Using
|
* Block and grid default configs are results after careful grid search.
|
||||||
* 36 blocks give the best or close to the best runtime on the devices I
|
* Using 36 blocks give the best or close to the best runtime on the devices
|
||||||
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
|
* I tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also
|
||||||
* take a small amount of SMs. Not quite sure the underlying reason, but my
|
* only take a small amount of SMs. Not quite sure the underlying reason,
|
||||||
* guess is that too many SMs will cause contention on NVLink bus.
|
* but my guess is that too many SMs will cause contention on NVLink bus.
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void allreduce(cudaStream_t stream, T* input, T* output, int size,
|
void allreduce(cudaStream_t stream, T* input, T* output, int size,
|
||||||
int threads = 512, int block_limit = 36) {
|
int threads = 512, int block_limit = defaultBlockLimit) {
|
||||||
auto d = packed_t<T>::P::size;
|
auto d = packed_t<T>::P::size;
|
||||||
if (size % d != 0)
|
if (size % d != 0)
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
@ -473,13 +558,11 @@ class CustomAllreduce {
|
|||||||
#define KL(ngpus, name) \
|
#define KL(ngpus, name) \
|
||||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||||
rank_, size);
|
rank_, size);
|
||||||
// TODO(hanzhi713): Threshold is different for A100 and H100.
|
|
||||||
// Add per device threshold.
|
|
||||||
#define REDUCE_CASE(ngpus) \
|
#define REDUCE_CASE(ngpus) \
|
||||||
case ngpus: { \
|
case ngpus: { \
|
||||||
if (world_size_ == 2) { \
|
if (world_size_ == 2) { \
|
||||||
KL(ngpus, cross_device_reduce_1stage); \
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
} else if (full_nvlink_) { \
|
} else if (fully_connected_) { \
|
||||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||||
KL(ngpus, cross_device_reduce_1stage); \
|
KL(ngpus, cross_device_reduce_1stage); \
|
||||||
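For readability, the size/topology heuristic encoded by the REDUCE_CASE macro above can be restated as a plain function. The thresholds are copied from the hunk; the behaviour of the branch that is cut off by this excerpt (not fully connected) is only sketched:

#include <cstddef>

enum class AlgoSketch { OneStage, TwoStage, NotUsed };

inline AlgoSketch pick_algo(int world_size, bool fully_connected, size_t bytes) {
  if (world_size == 2) return AlgoSketch::OneStage;      // always one-shot
  if (fully_connected) {
    if ((world_size <= 4 && bytes < 512 * 1024) ||
        (world_size <= 8 && bytes < 256 * 1024))
      return AlgoSketch::OneStage;                       // small payloads
    return AlgoSketch::TwoStage;                         // reduce-scatter + allgather
  }
  return AlgoSketch::NotUsed;  // branch not shown in this excerpt
}

One-shot wins for small payloads because it needs only two barriers and one pass over the data; the two-stage variant amortizes inter-GPU traffic better once the payload grows.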
@ -497,7 +580,8 @@ class CustomAllreduce {
|
|||||||
REDUCE_CASE(8)
|
REDUCE_CASE(8)
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
"custom allreduce only supports num gpus in (2,4,6,8). Actual "
|
||||||
|
"num "
|
||||||
"gpus = " +
|
"gpus = " +
|
||||||
std::to_string(world_size_));
|
std::to_string(world_size_));
|
||||||
}
|
}
|
||||||
@ -511,10 +595,11 @@ class CustomAllreduce {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
* To inspect PTX/SASS, copy paste this header file to compiler explorer and
|
||||||
a template instantiation:
|
add a template instantiation:
|
||||||
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||||
half *, int, int, int);
|
half *, int, int, int);
|
||||||
*/
|
*/
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
@ -1,9 +1,9 @@
|
|||||||
/**
|
/**
|
||||||
* This is a standalone test for custom allreduce.
|
* This is a standalone test for custom allreduce.
|
||||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||||
* export MPI_HOME=xxx
|
* export MPI_HOME=XXX
|
||||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||||
* custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi
|
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||||
*
|
*
|
||||||
* Warning: this C++ test is not designed to be very readable and was used
|
* Warning: this C++ test is not designed to be very readable and was used
|
||||||
* during the rapid prototyping process.
|
* during the rapid prototyping process.
|
||||||
@ -22,7 +22,15 @@
|
|||||||
#include "cuda_profiler_api.h"
|
#include "cuda_profiler_api.h"
|
||||||
#include "custom_all_reduce.cuh"
|
#include "custom_all_reduce.cuh"
|
||||||
#include "mpi.h"
|
#include "mpi.h"
|
||||||
#include "nccl.h"
|
#ifdef USE_ROCM
|
||||||
|
#include <hip/hip_bf16.h>
|
||||||
|
typedef __hip_bfloat16 nv_bfloat16;
|
||||||
|
#include "rccl/rccl.h"
|
||||||
|
#include "custom_all_reduce_hip.cuh"
|
||||||
|
#else
|
||||||
|
#include "nccl.h"
|
||||||
|
#include "custom_all_reduce.cuh"
|
||||||
|
#endif
|
||||||
|
|
||||||
#define MPICHECK(cmd) \
|
#define MPICHECK(cmd) \
|
||||||
do { \
|
do { \
|
||||||
@ -43,16 +51,29 @@
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
__global__ void dummy_kernel() {
|
__global__ void dummy_kernel() {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
for (int i = 0; i < 100; i++) {
|
||||||
|
uint64_t start = wall_clock64();
|
||||||
|
uint64_t cycles_elapsed;
|
||||||
|
do {
|
||||||
|
cycles_elapsed = wall_clock64() - start;
|
||||||
|
} while (cycles_elapsed < 100);
|
||||||
|
}
|
||||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
|
__global__ void dummy_kernel() {
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
|
||||||
|
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||||
|
#else
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
long long int start = clock64();
|
long long int start = clock64();
|
||||||
while (clock64() - start < 150000000); // approximately 98.4ms on P40
|
while (clock64() - start < 150000000); // approximately 98.4ms on P40
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void set_data(T* data, int size, int myRank) {
|
__global__ void set_data(T* data, int size, int myRank) {
|
||||||
@ -121,8 +142,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
|
|||||||
* registration, they are allocated and registered together in the test for
|
* registration, they are allocated and registered together in the test for
|
||||||
* convenience.
|
* convenience.
|
||||||
*/
|
*/
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
CUDACHECK(hipExtMallocWithFlags(
|
||||||
|
(void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal),
|
||||||
|
hipDeviceMallocUncached));
|
||||||
|
#else
|
||||||
CUDACHECK(
|
CUDACHECK(
|
||||||
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||||
|
#endif
|
||||||
CUDACHECK(
|
CUDACHECK(
|
||||||
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||||
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||||
@ -311,13 +338,18 @@ int main(int argc, char** argv) {
|
|||||||
|
|
||||||
bool performance_test = true;
|
bool performance_test = true;
|
||||||
cudaProfilerStart();
|
cudaProfilerStart();
|
||||||
// Uncomment to scan through different block size configs.
|
// Uncomment to scan through different block size configs.
|
||||||
// for (int threads : {256, 512, 1024}) {
|
// for (int threads : {256, 512, 1024}) {
|
||||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
// run<half>(myRank, nRanks, comm, threads, block_limit, 1024 * 1024,
|
||||||
// performance_test);
|
// performance_test);
|
||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
const int block_limit = 16;
|
||||||
|
#else
|
||||||
|
const int block_limit = 36;
|
||||||
|
#endif
|
||||||
// Scan through different sizes to test performance.
|
// Scan through different sizes to test performance.
|
||||||
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
||||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
||||||
@ -326,4 +358,4 @@ int main(int argc, char** argv) {
|
|||||||
cudaProfilerStop();
|
cudaProfilerStop();
|
||||||
MPICHECK(MPI_Finalize());
|
MPICHECK(MPI_Finalize());
|
||||||
return EXIT_SUCCESS;
|
return EXIT_SUCCESS;
|
||||||
}
|
}
|
||||||
@ -48,4 +48,14 @@ struct enable_sm90_or_later : Kernel {
|
|||||||
Kernel::operator()(std::forward<Args>(args)...);
|
Kernel::operator()(std::forward<Args>(args)...);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Kernel>
|
||||||
|
struct enable_sm90_only : Kernel {
|
||||||
|
template <typename... Args>
|
||||||
|
CUTLASS_DEVICE void operator()(Args&&... args) {
|
||||||
|
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 900
|
||||||
|
Kernel::operator()(std::forward<Args>(args)...);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
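The enable_sm90_only wrapper added above mirrors the existing enable_sm90_or_later pattern: the wrapped functor body is compiled away unless the device architecture matches, which keeps multi-architecture builds linkable. A hedged usage sketch, assuming the header above is included and with namespace qualification omitted; DummyBody is a placeholder functor, not a vLLM type:

#include "cutlass/cutlass.h"

struct DummyBody {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&...) { /* kernel work goes here */ }
};

// Calls through to DummyBody::operator() only when compiling for sm_90;
// on every other architecture the call body is empty.
using Sm90GatedBody = enable_sm90_only<DummyBody>;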
|
|||||||
@ -0,0 +1,457 @@
|
|||||||
|
/***************************************************************************************************
|
||||||
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||||
|
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
*
|
||||||
|
* Redistribution and use in source and binary forms, with or without
|
||||||
|
* modification, are permitted provided that the following conditions are met:
|
||||||
|
*
|
||||||
|
* 1. Redistributions of source code must retain the above copyright notice,
|
||||||
|
*this list of conditions and the following disclaimer.
|
||||||
|
*
|
||||||
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
* this list of conditions and the following disclaimer in the documentation
|
||||||
|
* and/or other materials provided with the distribution.
|
||||||
|
*
|
||||||
|
* 3. Neither the name of the copyright holder nor the names of its
|
||||||
|
* contributors may be used to endorse or promote products derived from
|
||||||
|
* this software without specific prior written permission.
|
||||||
|
*
|
||||||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||||
|
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||||
|
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||||
|
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||||
|
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||||
|
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||||
|
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||||
|
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||||
|
*POSSIBILITY OF SUCH DAMAGE.
|
||||||
|
*
|
||||||
|
**************************************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// This file is a modified excerpt of
|
||||||
|
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
|
||||||
|
// from https://github.com/NVIDIA/cutlass v3.5.0
|
||||||
|
// It has been modified to support either row/column or scalar broadcasting
|
||||||
|
// where the tensor being loaded from is always passed in via a device pointer.
|
||||||
|
// This lets one compiled kernel handle all cases of per-tensor or
|
||||||
|
// per-channel/per-token quantization.
|
||||||
|
//
|
||||||
|
// This interface also allows the scales to be passed in as tensors that
|
||||||
|
// consistently reside on the device, which avoids an issue with a previous
|
||||||
|
// implementation where scalars needed to be on the CPU since they
|
||||||
|
// were passed in via float values. This created a potential performance hazard
|
||||||
|
// if scales were initially on the device, and caused torch.compile graphs
|
||||||
|
// breaks when moving scales to the CPU.
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Turn off clang-format for the entire file to keep it close to upstream
|
||||||
|
// clang-format off
|
||||||
|
|
||||||
|
#include "cutlass/cutlass.h"
|
||||||
|
#include "cutlass/arch/barrier.h"
|
||||||
|
|
||||||
|
#include "cute/tensor.hpp"
|
||||||
|
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
|
||||||
|
|
||||||
|
namespace cutlass::epilogue::fusion {
|
||||||
|
|
||||||
|
using namespace cute;
|
||||||
|
using namespace detail;
|
||||||
|
|
||||||
|
// Row vector broadcast
|
||||||
|
template<
|
||||||
|
int Stages,
|
||||||
|
class CtaTileShapeMNK,
|
||||||
|
class Element,
|
||||||
|
class StrideMNL = Stride<_0,_1,_0>,
|
||||||
|
int Alignment = 128 / sizeof_bits_v<Element>
|
||||||
|
>
|
||||||
|
struct Sm90RowOrScalarBroadcastArray {
|
||||||
|
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||||
|
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||||
|
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||||
|
|
||||||
|
struct SharedStorage {
|
||||||
|
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||||
|
};
|
||||||
|
|
||||||
|
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||||
|
// scalar that must be broadcast, instead of containing a scalar that is
|
||||||
|
// valid if ptr_row is null.
|
||||||
|
struct Arguments {
|
||||||
|
const Element* const* ptr_row_array = nullptr;
|
||||||
|
bool row_broadcast = true;
|
||||||
|
StrideMNL dRow = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static constexpr Params
|
||||||
|
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static bool
|
||||||
|
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static size_t
|
||||||
|
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static cutlass::Status
|
||||||
|
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||||
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90RowOrScalarBroadcastArray() { }
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90RowOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||||
|
: params(params)
|
||||||
|
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||||
|
|
||||||
|
Params params;
|
||||||
|
Element *smem = nullptr;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_producer_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_C_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_zero() const {
|
||||||
|
return (!params.row_broadcast && *(params.ptr_row_array[group]) == Element(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||||
|
return EmptyProducerLoadCallbacks{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||||
|
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
ConsumerStoreCallbacks(
|
||||||
|
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||||
|
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||||
|
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||||
|
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_,
|
||||||
|
int group, Params const& params_)
|
||||||
|
: tGS_gRow(tGS_gRow_)
|
||||||
|
, tGS_sRow(tGS_sRow_)
|
||||||
|
, tGS_cRow(tGS_cRow_)
|
||||||
|
, tiled_G2S(tiled_g2s_)
|
||||||
|
, tSR_sRow(tSR_sRow_)
|
||||||
|
, tSR_rRow(tSR_rRow_)
|
||||||
|
, tCcRow(tCcRow_)
|
||||||
|
, residue_tCcRow(residue_tCcRow_)
|
||||||
|
, group(group)
|
||||||
|
, params(params_) {}
|
||||||
|
|
||||||
|
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||||
|
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||||
|
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||||
|
Tiled_G2S tiled_G2S;
|
||||||
|
|
||||||
|
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
|
||||||
|
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
ThrResidue residue_tCcRow; // (m, n)
|
||||||
|
ThrNum thr_num;
|
||||||
|
int group;
|
||||||
|
Params const& params;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
begin() {
|
||||||
|
if (!params.row_broadcast) {
|
||||||
|
fill(tSR_rRow, *(params.ptr_row_array[group]));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||||
|
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||||
|
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||||
|
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||||
|
|
||||||
|
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||||
|
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||||
|
continue; // OOB of SMEM,
|
||||||
|
}
|
||||||
|
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||||
|
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
begin_loop(int epi_m, int epi_n) {
|
||||||
|
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||||
|
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||||
|
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||||
|
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||||
|
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ElementAccumulator, int FragmentSize>
|
||||||
|
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||||
|
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||||
|
Array<Element, FragmentSize> frg_row;
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < FragmentSize; ++i) {
|
||||||
|
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return frg_row;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||||
|
class... Args
|
||||||
|
>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||||
|
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||||
|
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||||
|
using ThreadCount = decltype(size(args.tiled_copy));
|
||||||
|
|
||||||
|
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow);
|
||||||
|
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||||
|
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||||
|
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||||
|
//// G2S: Gmem to Smem
|
||||||
|
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||||
|
Layout< Shape<_1, ThreadCount>,
|
||||||
|
Stride<_0, _1>>{},
|
||||||
|
Layout<_1>{});
|
||||||
|
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||||
|
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||||
|
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||||
|
|
||||||
|
//// G2S: Coord
|
||||||
|
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||||
|
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||||
|
|
||||||
|
//// S2R: Smem to Reg
|
||||||
|
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||||
|
|
||||||
|
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||||
|
tGS_gRow,
|
||||||
|
tGS_sRow,
|
||||||
|
tGS_cRow, tiled_g2s,
|
||||||
|
tSR_sRow,
|
||||||
|
tSR_rRow,
|
||||||
|
args.tCcD,
|
||||||
|
args.residue_cD,
|
||||||
|
ThreadCount{},
|
||||||
|
l,
|
||||||
|
params);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
// Column vector broadcast
|
||||||
|
template<
|
||||||
|
int Stages,
|
||||||
|
class CtaTileShapeMNK,
|
||||||
|
class Element,
|
||||||
|
class StrideMNL = Stride<_1,_0,_0>,
|
||||||
|
int Alignment = 128 / sizeof_bits_v<Element>
|
||||||
|
>
|
||||||
|
struct Sm90ColOrScalarBroadcastArray {
|
||||||
|
static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
|
||||||
|
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||||
|
static_assert(
|
||||||
|
(cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
|
||||||
|
(cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
|
||||||
|
|
||||||
|
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
|
||||||
|
struct SharedStorage { };
|
||||||
|
|
||||||
|
// This struct has been modified to have a bool indicating that ptr_col is a
|
||||||
|
// scalar that must be broadcast, instead of containing a scalar that is
|
||||||
|
// valid if ptr_col is null.
|
||||||
|
struct Arguments {
|
||||||
|
const Element* const* ptr_col_array = nullptr;
|
||||||
|
bool col_broadcast = true;
|
||||||
|
StrideMNL dCol = {};
|
||||||
|
};
|
||||||
|
|
||||||
|
using Params = Arguments;
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static constexpr Params
|
||||||
|
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
|
||||||
|
return args;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static bool
|
||||||
|
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static size_t
|
||||||
|
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class ProblemShape>
|
||||||
|
static cutlass::Status
|
||||||
|
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
|
||||||
|
CudaHostAdapter* cuda_adapter = nullptr) {
|
||||||
|
return cutlass::Status::kSuccess;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_producer_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_C_load_needed() const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_DEVICE bool
|
||||||
|
is_zero() const {
|
||||||
|
return (!params.col_broadcast && *(params.ptr_col_array[group]) == Element(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90ColOrScalarBroadcastArray() { }
|
||||||
|
|
||||||
|
CUTLASS_HOST_DEVICE
|
||||||
|
Sm90ColOrScalarBroadcastArray(Params const& params, SharedStorage const& shared_storage)
|
||||||
|
: params(params) { }
|
||||||
|
|
||||||
|
Params params;
|
||||||
|
|
||||||
|
template <class... Args>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||||
|
return EmptyProducerLoadCallbacks{};
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||||
|
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||||
|
CUTLASS_DEVICE
|
||||||
|
ConsumerStoreCallbacks(
|
||||||
|
GTensor&& tCgCol,
|
||||||
|
RTensor&& tCrCol,
|
||||||
|
CTensor&& tCcCol,
|
||||||
|
ProblemShape problem_shape,
|
||||||
|
int group,
|
||||||
|
Params const& params
|
||||||
|
):
|
||||||
|
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||||
|
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||||
|
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||||
|
m(get<0>(problem_shape)),
|
||||||
|
group(group),
|
||||||
|
params(params) {}
|
||||||
|
|
||||||
|
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
RTensor tCrCol;
|
||||||
|
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
Params const& params;
|
||||||
|
int m;
|
||||||
|
int group;
|
||||||
|
|
||||||
|
CUTLASS_DEVICE void
|
||||||
|
begin() {
|
||||||
|
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < size(pred); ++i) {
|
||||||
|
pred(i) = get<0>(tCcCol(i)) < m;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!params.col_broadcast) {
|
||||||
|
fill(tCrCol, *(params.ptr_col_array[group]));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter so we don't issue redundant copies over stride-0 modes
|
||||||
|
// (only works if 0-strides are in same location, which is by construction)
|
||||||
|
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ElementAccumulator, int FragmentSize>
|
||||||
|
CUTLASS_DEVICE Array<Element, FragmentSize>
|
||||||
|
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
|
||||||
|
Array<Element, FragmentSize> frg_col;
|
||||||
|
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
|
||||||
|
|
||||||
|
CUTLASS_PRAGMA_UNROLL
|
||||||
|
for (int i = 0; i < FragmentSize; ++i) {
|
||||||
|
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return frg_col;
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template <
|
||||||
|
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
|
||||||
|
class... Args
|
||||||
|
>
|
||||||
|
CUTLASS_DEVICE auto
|
||||||
|
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||||
|
|
||||||
|
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||||
|
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||||
|
|
||||||
|
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol);
|
||||||
|
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
|
||||||
|
// Generate an identity tensor matching the shape of the global tensor and
|
||||||
|
// partition the same way, this will be used to generate the predicate
|
||||||
|
// tensor for loading
|
||||||
|
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||||
|
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||||
|
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||||
|
|
||||||
|
return ConsumerStoreCallbacks(
|
||||||
|
cute::move(tCgCol),
|
||||||
|
cute::move(tCrCol),
|
||||||
|
cute::move(tCcCol),
|
||||||
|
args.problem_shape_mnkl,
|
||||||
|
l,
|
||||||
|
params
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
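The two *OrScalarBroadcastArray visitors above boil down to one idea: the epilogue receives an array of device pointers (one per group) plus a flag, and either streams a whole row/column of scales or splats element 0. A stripped-down device-side illustration of that selection; the struct and names here are assumptions, not CUTLASS API:

struct ScaleArgSketch {
  const float* const* ptr_array;  // one device pointer per group (group GEMM)
  bool broadcast_vector;          // true: per-row/col scales, false: one scalar
};

// What the visitor effectively loads for coordinate `index` of `group`.
__device__ inline float load_scale(const ScaleArgSketch& s, int group, int index) {
  return s.broadcast_vector ? s.ptr_array[group][index] : s.ptr_array[group][0];
}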
@ -1,6 +1,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
|
||||||
|
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
This file defines custom epilogues for fusing channel scales, token scales,
|
This file defines custom epilogues for fusing channel scales, token scales,
|
||||||
@ -69,6 +70,16 @@ struct ScaledEpilogueBase {
|
|||||||
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
0 /*Stages*/, TileShape, T, T, Stride<Int<0>, Int<1>, Int<0>>,
|
||||||
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
128 / sizeof_bits_v<T>, EnableNullPtr>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using ColOrScalarLoadArray =
|
||||||
|
cutlass::epilogue::fusion::Sm90ColOrScalarBroadcastArray<
|
||||||
|
0 /*Stages*/, TileShape, T, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using RowOrScalarLoadArray =
|
||||||
|
cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
|
||||||
|
0 /*Stages*/, TileShape, T, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||||
|
|
||||||
// This utility function constructs the arguments for the load descriptors
|
// This utility function constructs the arguments for the load descriptors
|
||||||
// from a tensor. It can handle both row and column, as well as row/column or
|
// from a tensor. It can handle both row and column, as well as row/column or
|
||||||
// scalar cases.
|
// scalar cases.
|
||||||
@ -96,6 +107,14 @@ struct ScaledEpilogueBase {
|
|||||||
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
std::is_same_v<Descriptor, RowLoad<T, true>>);
|
||||||
return Arguments{data_ptr};
|
return Arguments{data_ptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Descriptor, typename T>
|
||||||
|
static auto args_from_tensor(const T* const* data_ptr, bool do_broadcast) {
|
||||||
|
using Arguments = typename Descriptor::Arguments;
|
||||||
|
static_assert(std::is_same_v<Descriptor, ColOrScalarLoadArray<T>> ||
|
||||||
|
std::is_same_v<Descriptor, RowOrScalarLoadArray<T>>);
|
||||||
|
return Arguments{data_ptr, do_broadcast};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -381,4 +400,51 @@ struct ScaledEpilogueBiasAzpToken
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
This epilogue works like ScaledEpilogue, but ScaleA and ScaleB are pointers
|
||||||
|
to arrays containing different scales used in group gemm. The number of
|
||||||
|
pointers in ScaleA and the number of pointers in ScaleB are equal to the
|
||||||
|
group size.
|
||||||
|
*/
|
||||||
|
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
|
||||||
|
struct ScaledEpilogueArray
|
||||||
|
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
|
||||||
|
private:
|
||||||
|
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
|
||||||
|
using Accum = typename SUPER::Accum;
|
||||||
|
using ScaleA = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||||
|
using ScaleB = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||||
|
|
||||||
|
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, float, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
using EVTCompute0 =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||||
|
|
||||||
|
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||||
|
cutlass::multiplies, ElementD, float,
|
||||||
|
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using EVTCompute =
|
||||||
|
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
|
||||||
|
using ArgumentType = typename EVTCompute::Arguments;
|
||||||
|
|
||||||
|
using ScaleAArray = typename SUPER::template ColOrScalarLoadArray<float>;
|
||||||
|
using ScaleBArray = typename SUPER::template RowOrScalarLoadArray<float>;
|
||||||
|
|
||||||
|
static ArgumentType prepare_args(float const* const* a_scales_ptr,
|
||||||
|
float const* const* b_scales_ptr,
|
||||||
|
bool a_col_broadcast, bool b_row_broadcast) {
|
||||||
|
auto a_args = SUPER::template args_from_tensor<ScaleAArray, float>(
|
||||||
|
a_scales_ptr, a_col_broadcast);
|
||||||
|
auto b_args = SUPER::template args_from_tensor<ScaleBArray, float>(
|
||||||
|
b_scales_ptr, b_row_broadcast);
|
||||||
|
|
||||||
|
typename EVTCompute0::Arguments evt0_args{b_args, {}, {}};
|
||||||
|
return ArgumentType{a_args, evt0_args, {}};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
}; // namespace vllm::c3x
|
}; // namespace vllm::c3x
|
||||||
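Numerically, the EVT tree built by ScaledEpilogueArray above is just two multiplies per output element: Compute0 forms b_scale * acc in float, Compute1 multiplies by a_scale and converts to ElementD. As a scalar reference:

// Scalar reference for D = a_scale * (b_scale * acc); the kernel applies this
// per element with a_scale indexed per row (or splatted) and b_scale per
// column (or splatted), using the pointer arrays shown above.
__host__ __device__ inline float scaled_epilogue_ref(float acc, float a_scale,
                                                     float b_scale) {
  return a_scale * (b_scale * acc);
}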
|
|||||||
@ -422,7 +422,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
   int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
   // in case the final state is separated between the last "smem_exchange" and
   // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
-  // (which occurs when `final_state_position` is a non-positivie index)
+  // (which occurs when `final_state_position` is a non-positive index)
   // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
   if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
     input_t vals_load[kNElts] = {0};
|||||||
csrc/moe/marlin_moe_wna16/generate_kernels.py (new file, 103 lines)
@ -0,0 +1,103 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import glob
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
|
||||||
|
FILE_HEAD = """
|
||||||
|
// auto generated by generate.py
|
||||||
|
// clang-format off
|
||||||
|
|
||||||
|
#include "kernel.h"
|
||||||
|
#include "marlin_template.h"
|
||||||
|
|
||||||
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
TEMPLATE = ("template __global__ void Marlin<"
|
||||||
|
"{{scalar_t}}, "
|
||||||
|
"{{w_type_id}}, "
|
||||||
|
"{{threads}}, "
|
||||||
|
"{{thread_m_blocks}}, "
|
||||||
|
"{{thread_n_blocks}}, "
|
||||||
|
"{{thread_k_blocks}}, "
|
||||||
|
"{{'true' if m_block_size_8 else 'false'}}, "
|
||||||
|
"{{stages}}, "
|
||||||
|
"{{'true' if has_act_order else 'false'}}, "
|
||||||
|
"{{'true' if has_zp else 'false'}}, "
|
||||||
|
"{{group_blocks}}, "
|
||||||
|
"{{'true' if is_zp_float else 'false'}}>"
|
||||||
|
"( MARLIN_KERNEL_PARAMS );")
|
||||||
|
|
||||||
|
# int8 with zero point case (vllm::kU8) is also supported,
|
||||||
|
# we don't add it to reduce wheel size.
|
||||||
|
SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"]
|
||||||
|
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
|
||||||
|
|
||||||
|
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
|
||||||
|
# group_blocks:
|
||||||
|
# = 0 : act order case
|
||||||
|
# = -1 : channelwise quantization
|
||||||
|
# > 0 : group_size=16*group_blocks
|
||||||
|
GROUP_BLOCKS = [0, -1, 2, 4, 8]
|
||||||
|
DTYPES = ["fp16", "bf16"]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_old_kernels():
|
||||||
|
for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
|
||||||
|
subprocess.call(["rm", "-f", filename])
|
||||||
|
|
||||||
|
|
||||||
|
def generate_new_kernels():
|
||||||
|
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
|
||||||
|
has_zp = "B" not in scalar_type
|
||||||
|
all_template_str_list = []
|
||||||
|
|
||||||
|
for group_blocks, m_blocks, thread_configs in itertools.product(
|
||||||
|
GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
|
||||||
|
|
||||||
|
has_act_order = group_blocks == 0
|
||||||
|
if has_zp and has_act_order:
|
||||||
|
continue
|
||||||
|
if thread_configs[2] == 256:
|
||||||
|
if m_blocks <= 1 and thread_configs[0] != 128:
|
||||||
|
continue
|
||||||
|
if m_blocks > 1 and thread_configs[0] != 64:
|
||||||
|
continue
|
||||||
|
|
||||||
|
k_blocks = thread_configs[0] // 16
|
||||||
|
n_blocks = thread_configs[1] // 16
|
||||||
|
threads = thread_configs[2]
|
||||||
|
|
||||||
|
c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
|
||||||
|
|
||||||
|
template_str = jinja2.Template(TEMPLATE).render(
|
||||||
|
scalar_t=c_dtype,
|
||||||
|
w_type_id=scalar_type + ".id()",
|
||||||
|
threads=threads,
|
||||||
|
thread_m_blocks=max(m_blocks, 1),
|
||||||
|
thread_n_blocks=n_blocks,
|
||||||
|
thread_k_blocks=k_blocks,
|
||||||
|
m_block_size_8=m_blocks == 0.5,
|
||||||
|
stages="pipe_stages",
|
||||||
|
has_act_order=has_act_order,
|
||||||
|
has_zp=has_zp,
|
||||||
|
group_blocks=group_blocks,
|
||||||
|
is_zp_float=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_template_str_list.append(template_str)
|
||||||
|
|
||||||
|
file_content = FILE_HEAD + "\n\n"
|
||||||
|
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
|
||||||
|
filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu"
|
||||||
|
|
||||||
|
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
|
||||||
|
f.write(file_content)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
remove_old_kernels()
|
||||||
|
generate_new_kernels()
|
||||||
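Rendering the TEMPLATE string above for one parameter combination produces a file like the following. This is a single illustrative instantiation (the chosen values satisfy the script's constraints but are otherwise arbitrary); the script emits many such instantiations per dtype/scalar-type pair:

// auto generated by generate.py
// clang-format off

#include "kernel.h"
#include "marlin_template.h"

namespace MARLIN_NAMESPACE_NAME {

template __global__ void Marlin<half, vllm::kU4B8.id(), 256, 4, 16, 4, false,
                                pipe_stages, false, false, 8, false>
    ( MARLIN_KERNEL_PARAMS );

}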
csrc/moe/marlin_moe_wna16/kernel.h (new file, 44 lines)
@ -0,0 +1,44 @@
|
|||||||
|
|
||||||
|
#ifndef MARLIN_NAMESPACE_NAME
|
||||||
|
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "quantization/gptq_marlin/marlin.cuh"
|
||||||
|
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
|
||||||
|
#include "core/scalar_type.hpp"
|
||||||
|
|
||||||
|
#define MARLIN_KERNEL_PARAMS \
|
||||||
|
const int4 *__restrict__ A, const int4 *__restrict__ B, \
|
||||||
|
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
|
||||||
|
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
|
||||||
|
const int *__restrict__ g_idx, \
|
||||||
|
const int32_t *__restrict__ sorted_token_ids_ptr, \
|
||||||
|
const int32_t *__restrict__ expert_ids_ptr, \
|
||||||
|
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
|
||||||
|
const float *__restrict__ topk_weights_ptr, int top_k, \
|
||||||
|
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
|
||||||
|
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
|
||||||
|
bool use_fp32_reduce
|
||||||
|
|
||||||
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
|
template <typename scalar_t, // compute dtype, half or nv_float16
|
||||||
|
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
|
||||||
|
const int threads, // number of threads in a threadblock
|
||||||
|
const int thread_m_blocks, // number of 16x16 blocks in the m
|
||||||
|
// dimension (batchsize) of the
|
||||||
|
// threadblock
|
||||||
|
const int thread_n_blocks, // same for n dimension (output)
|
||||||
|
const int thread_k_blocks, // same for k dimension (reduction)
|
||||||
|
const bool m_block_size_8, // whether m_block_size == 8
|
||||||
|
// only works when thread_m_blocks == 1
|
||||||
|
const int stages, // number of stages for the async global->shared
|
||||||
|
// fetch pipeline
|
||||||
|
const bool has_act_order, // whether act_order is enabled
|
||||||
|
const bool has_zp, // whether zero-points are enabled
|
||||||
|
const int group_blocks, // number of consecutive 16x16 blocks
|
||||||
|
// with a separate quantization scale
|
||||||
|
const bool is_zp_float // is zero point of float16 type?
|
||||||
|
>
|
||||||
|
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
|
||||||
|
|
||||||
|
}
|
||||||
csrc/moe/marlin_moe_wna16/marlin_template.h (new file, 1917 lines; diff suppressed because it is too large)
csrc/moe/marlin_moe_wna16/ops.cu (new file, 927 lines)
@ -0,0 +1,927 @@
|
|||||||
|
/*
|
||||||
|
* Modified by Neural Magic
|
||||||
|
* Copyright (C) Marlin.2024 Elias Frantar
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Adapted from https://github.com/IST-DASLab/marlin
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MARLIN_NAMESPACE_NAME
|
||||||
|
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "kernel.h"
|
||||||
|
#include "core/registration.h"
|
||||||
|
|
||||||
|
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||||
|
static_assert(std::is_same<scalar_t, half>::value || \
|
||||||
|
std::is_same<scalar_t, nv_bfloat16>::value, \
|
||||||
|
"only float16 and bfloat16 is supported");
|
||||||
|
|
||||||
|
namespace MARLIN_NAMESPACE_NAME {
|
||||||
|
|
||||||
|
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
|
||||||
|
|
||||||
|
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||||
|
|
||||||
|
template <int moe_block_size>
|
||||||
|
__global__ void permute_cols_kernel(
|
||||||
|
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||||
|
int4* __restrict__ out_int4_ptr,
|
||||||
|
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||||
|
const int32_t* __restrict__ expert_ids_ptr,
|
||||||
|
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||||
|
int size_k, int top_k) {};
|
||||||
|
|
||||||
|
} // namespace marlin
|
||||||
|
|
||||||
|
torch::Tensor moe_wna16_marlin_gemm(
|
||||||
|
torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
|
||||||
|
torch::Tensor& b_q_weight, torch::Tensor& b_scales,
|
||||||
|
std::optional<torch::Tensor> const& b_zeros_or_none,
|
||||||
|
std::optional<torch::Tensor> const& g_idx_or_none,
|
||||||
|
std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
|
||||||
|
torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
|
||||||
|
torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
|
||||||
|
int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
|
||||||
|
vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
|
||||||
|
int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
|
||||||
|
bool is_zp_float) {
|
||||||
|
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||||
|
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||||
|
return torch::empty({1, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
// For a given "a" of size [M,K] performs a permutation of the K columns based
|
||||||
|
// on the given "perm" indices.
|
||||||
|
template <int moe_block_size>
|
||||||
|
__global__ void permute_cols_kernel(
|
||||||
|
int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr,
|
||||||
|
int4* __restrict__ out_int4_ptr,
|
||||||
|
const int32_t* __restrict__ sorted_token_ids_ptr,
|
||||||
|
const int32_t* __restrict__ expert_ids_ptr,
|
||||||
|
const int32_t* __restrict__ num_tokens_past_padded_ptr, int size_m,
|
||||||
|
int size_k, int top_k) {
|
||||||
|
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
|
||||||
|
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
|
||||||
|
int32_t block_sorted_ids[moe_block_size];
|
||||||
|
int block_num_valid_tokens = 0;
|
||||||
|
int64_t old_expert_id = 0;
|
||||||
|
int64_t expert_id = 0;
|
||||||
|
int row_stride = size_k * sizeof(half) / 16;
|
||||||
|
|
||||||
|
auto read_moe_block_data = [&](int block_id) {
|
||||||
|
block_num_valid_tokens = moe_block_size;
|
||||||
|
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
|
||||||
|
for (int i = 0; i < moe_block_size / 4; i++) {
|
||||||
|
tmp_block_sorted_ids[i] =
|
||||||
|
((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
|
||||||
|
}
|
||||||
|
for (int i = 0; i < moe_block_size; i++) {
|
||||||
|
if (block_sorted_ids[i] >= size_m * top_k) {
|
||||||
|
block_num_valid_tokens = i;
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto permute_row = [&](int row) {
|
||||||
|
int iters = size_k / default_threads;
|
||||||
|
int rest = size_k % default_threads;
|
||||||
|
|
||||||
|
int in_offset = (row / top_k) * row_stride;
|
||||||
|
int out_offset = row * row_stride;
|
||||||
|
|
||||||
|
half const* a_row_half =
|
||||||
|
reinterpret_cast<half const*>(a_int4_ptr + in_offset);
|
||||||
|
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
|
||||||
|
|
||||||
|
int base_k = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < iters; i++) {
|
||||||
|
int cur_k = base_k + threadIdx.x;
|
||||||
|
int src_pos = perm_int_ptr[cur_k];
|
||||||
|
|
||||||
|
out_half[cur_k] = a_row_half[src_pos];
|
||||||
|
|
||||||
|
base_k += default_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rest) {
|
||||||
|
if (threadIdx.x < rest) {
|
||||||
|
int cur_k = base_k + threadIdx.x;
|
||||||
|
int src_pos = perm_int_ptr[cur_k];
|
||||||
|
|
||||||
|
out_half[cur_k] = a_row_half[src_pos];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
|
||||||
|
old_expert_id = expert_id;
|
||||||
|
int tmp_expert_id = expert_ids_ptr[index];
|
||||||
|
if (tmp_expert_id == -1) continue;
|
||||||
|
expert_id = tmp_expert_id;
|
||||||
|
perm_int_ptr += (expert_id - old_expert_id) * size_k;
|
||||||
|
read_moe_block_data(index);
|
||||||
|
|
||||||
|
for (int i = 0; i < block_num_valid_tokens; i++)
|
||||||
|
permute_row(block_sorted_ids[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
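As a mental model for permute_cols_kernel above (ignoring the MoE block and expert-id bookkeeping), each output row `row` copies input row `row / top_k` with its K columns reordered by `perm`. A CPU reference of just that core permutation, useful as a test oracle; it is not part of the CUDA sources:

#include <vector>

std::vector<float> permute_cols_ref(const std::vector<float>& a,   // [M, K]
                                    const std::vector<int>& perm,  // [K]
                                    int M, int K, int top_k) {
  std::vector<float> out(static_cast<size_t>(M) * top_k * K);
  for (int row = 0; row < M * top_k; ++row)
    for (int k = 0; k < K; ++k)
      out[static_cast<size_t>(row) * K + k] =
          a[static_cast<size_t>(row / top_k) * K + perm[k]];
  return out;
}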
|
typedef struct {
|
||||||
|
int thread_k;
|
||||||
|
int thread_n;
|
||||||
|
int num_threads;
|
||||||
|
} thread_config_t;
|
||||||
|
|
||||||
|
thread_config_t small_batch_thread_configs[] = {
|
||||||
|
// Ordered by priority
|
||||||
|
|
||||||
|
// thread_k, thread_n, num_threads
|
||||||
|
{128, 128, 256},
|
||||||
|
{64, 128, 128}};
|
||||||
|
|
||||||
|
thread_config_t large_batch_thread_configs[] = {
|
||||||
|
// Ordered by priority
|
||||||
|
|
||||||
|
// thread_k, thread_n, num_threads
|
||||||
|
{64, 256, 256},
|
||||||
|
{64, 128, 128}};
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
int blocks_per_sm;
|
||||||
|
thread_config_t tb_cfg;
|
||||||
|
} exec_config_t;
|
||||||
|
|
||||||
|
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
|
||||||
|
int prob_n, int prob_k, int num_bits, int group_size,
|
||||||
|
bool has_act_order, bool is_k_full) {
|
||||||
|
bool cache_scales_chunk = has_act_order && !is_k_full;
|
||||||
|
|
||||||
|
int tb_n = th_config.thread_n;
|
||||||
|
int tb_k = th_config.thread_k;
|
||||||
|
|
||||||
|
// Get max scale groups per thread-block
|
||||||
|
int tb_groups;
|
||||||
|
if (group_size == -1) {
|
||||||
|
tb_groups = 1;
|
||||||
|
} else if (group_size == 0) {
|
||||||
|
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
|
||||||
|
} else {
|
||||||
|
tb_groups = div_ceil(tb_k, group_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cache_scales_chunk) {
|
||||||
|
int load_groups =
|
||||||
|
tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
|
||||||
|
load_groups = max(load_groups, 32); // We load at least 32 scale groups
|
||||||
|
return load_groups * tb_n * 2;
|
||||||
|
|
||||||
|
} else {
|
||||||
|
int tb_scales = tb_groups * tb_n * 2;
|
||||||
|
|
||||||
|
return tb_scales * pipe_stages;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks,
                          int prob_m, int prob_n, int prob_k, int num_bits,
                          int group_size, bool has_act_order, bool is_k_full,
                          int has_zp, int is_zp_float) {
  int pack_factor = 32 / num_bits;

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;
  int tb_m = thread_m_blocks * 16;

  // shm size for block_sorted_ids/block_topk_weights
  // both of them require tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
  int sh_block_meta_size = tb_m * 4 * 2;
  int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
  int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
  int sh_s_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);
  int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
  int sh_zp_size = 0;
  if (has_zp) {
    if (is_zp_float)
      sh_zp_size = sh_s_size;
    else if (num_bits == 4)
      sh_zp_size = sh_s_size / 4;
    else if (num_bits == 8)
      sh_zp_size = sh_s_size / 2;
  }

  int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size +
                   sh_g_idx_size + sh_block_meta_size;

  return total_size;
}
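
// Worked example (illustrative only; assumes pipe_stages == 4 as configured
// elsewhere in this file): for thread_config_t{64, 256, 256} with
// moe_block_size = 64, 4-bit weights, group_size = 128, no act-order and no
// zero-points:
//   tb_m = 64, tb_k = 64, tb_n = 256, pack_factor = 8
//   sh_a_size          = 4 * (64 * 64) * 2       = 32768 bytes
//   sh_b_size          = 4 * (64 * 256 / 8) * 4  = 32768 bytes
//   sh_s_size          = 1 group * 256 * 2 * 4   =  2048 bytes
//   sh_block_meta_size = 64 * 4 * 2              =   512 bytes
//   total              = 68096 bytes (~66.5 KiB), which must fit into the
//   opt-in shared-memory limit checked by is_valid_config() below.
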
bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int has_zp, int is_zp_float, int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  // Check that pipeline fits into cache
  int cache_size = get_kernel_cache_size(
      th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size,
      has_act_order, is_k_full, has_zp, is_zp_float);
  return cache_size <= max_shared_mem;
}
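
// Note: max_shared_mem here is the per-block opt-in limit queried in
// marlin_mm() below via cudaDeviceGetAttribute(&max_shared_mem,
// cudaDevAttrMaxSharedMemoryPerBlockOptin, dev), so a config is rejected when
// its pipelined A/B/scale/zero-point staging buffers do not fit in dynamic
// shared memory.
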
#define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
                 M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS,       \
                 NUM_THREADS, IS_ZP_FLOAT)                                  \
  else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS &&        \
           thread_n_blocks == THREAD_N_BLOCKS &&                            \
           thread_k_blocks == THREAD_K_BLOCKS &&                            \
           m_block_size_8 == M_BLOCK_SIZE_8 &&                              \
           has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP &&            \
           group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS &&    \
           is_zp_float == IS_ZP_FLOAT) {                                    \
    kernel = Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS,    \
                    THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8,       \
                    pipe_stages, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS,       \
                    IS_ZP_FLOAT>;                                           \
  }
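
// Expansion sketch: __GET_IF(vllm::kU4B8, 1, 8, 8, true, true, false, 0, 256,
// false) expands (roughly) to
//   else if (q_type == vllm::kU4B8 && thread_m_blocks == 1 &&
//            thread_n_blocks == 8 && thread_k_blocks == 8 &&
//            m_block_size_8 == true && has_act_order == true &&
//            has_zp == false && group_blocks == 0 && num_threads == 256 &&
//            is_zp_float == false) {
//     kernel = Marlin<scalar_t, vllm::kU4B8.id(), 256, 1, 8, 8, true,
//                     pipe_stages, true, false, 0, false>;
//   }
// which is why the dispatch chain in get_marlin_kernel() below has to start
// with an empty `if (false) {}`.
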
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)               \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS,  \
           false)                                                             \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0,              \
           NUM_THREADS, false)                                                \
                                                                              \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1,            \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8,             \
           NUM_THREADS, false)

#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)             \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0,              \
           NUM_THREADS, false)                                                \
                                                                              \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1,            \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8,             \
           NUM_THREADS, false)                                                \
                                                                              \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1,            \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8,             \
           NUM_THREADS, false)                                                \
                                                                              \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1,            \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8,             \
           NUM_THREADS, false)

#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS,  \
           false)                                                             \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS,  \
           false)                                                             \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS,  \
           false)                                                             \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8,              \
           NUM_THREADS, false)

#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)              \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8,              \
           NUM_THREADS, false)                                                \
                                                                              \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8,              \
           NUM_THREADS, false)                                                \
                                                                              \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1,             \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, false)                                                \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8,              \
           NUM_THREADS, false)

// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS)                   \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS,  \
           true)                                                              \
  __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, true)                                                 \
  __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, true)                                                 \
  __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, true)                                                 \
  __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4,              \
           NUM_THREADS, true)

template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type,
                                int thread_m_blocks, int thread_n_blocks,
                                int thread_k_blocks, bool m_block_size_8,
                                bool has_act_order, bool has_zp,
                                int group_blocks, int num_threads,
                                bool is_zp_float) {
  int num_bits = q_type.size_bits();
  auto kernel = MarlinDefault;
  if (false) {
  }
  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256)
  GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128)

  GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256)
  GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128)

  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256)
  GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128)

  GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256)
  GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128)

  AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256)
  AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128)

  AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256)
  AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128)

  return kernel;
}

template <typename scalar_t>
exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m,
                                    int prob_n, int prob_k, int thread_m_blocks,
                                    bool m_block_size_8, int num_bits,
                                    int group_size, bool has_act_order,
                                    bool is_k_full, bool has_zp,
                                    bool is_zp_float, int max_shared_mem) {
  exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
  thread_config_t* thread_configs = thread_m_blocks > 1
                                        ? large_batch_thread_configs
                                        : small_batch_thread_configs;
  int thread_configs_size =
      thread_m_blocks > 1
          ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
          : sizeof(small_batch_thread_configs) / sizeof(thread_config_t);

  int count = 0;
  constexpr int device_max_reg_size = 255 * 1024;
  for (int i = 0; i < thread_configs_size; i++) {
    thread_config_t th_config = thread_configs[i];

    if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k,
                         num_bits, group_size, has_act_order, is_k_full, has_zp,
                         is_zp_float, max_shared_mem)) {
      continue;
    }

    int cache_size = get_kernel_cache_size(
        th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits,
        group_size, has_act_order, is_k_full, has_zp, is_zp_float);

    int group_blocks = 0;
    if (!has_act_order) {
      group_blocks = group_size == -1 ? -1 : group_size / 16;
    }

    auto kernel = get_marlin_kernel<scalar_t>(
        q_type, thread_m_blocks, th_config.thread_n / 16,
        th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp,
        group_blocks, th_config.num_threads, is_zp_float);

    if (kernel == MarlinDefault) continue;

    if (thread_m_blocks > 1) {
      exec_cfg = {1, th_config};
      break;
    } else {
      cudaFuncAttributes attr;
      cudaFuncGetAttributes(&attr, kernel);
      int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
      int allow_count = min(device_max_reg_size / reg_size,
                            max_shared_mem / (cache_size + 1024));
      allow_count = max(min(allow_count, 4), 1);
      if (allow_count > count) {
        count = allow_count;
        exec_cfg = {count, th_config};
      }
    }
  }

  return exec_cfg;
}
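
// Worked example of the occupancy heuristic above (illustrative numbers only):
// if cudaFuncGetAttributes reports numRegs = 120 for a 256-thread config, then
// reg_size = 120 * 256 * 4 = 122880 bytes and
// device_max_reg_size / reg_size = 261120 / 122880 = 2; with, say,
// max_shared_mem = 101376 and cache_size = 32768, the shared-memory bound is
// 101376 / 33792 = 3, so blocks_per_sm = min(2, 3) clamped to [1, 4] = 2.
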
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
               void* zp, void* g_idx, void* perm, void* a_tmp,
               void* sorted_token_ids, void* expert_ids,
               void* num_tokens_past_padded, void* topk_weights,
               int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep,
               int prob_m, int prob_n, int prob_k, void* workspace,
               vllm::ScalarType const& q_type, bool has_act_order,
               bool is_k_full, bool has_zp, int num_groups, int group_size,
               int dev, cudaStream_t stream, int thread_k, int thread_n,
               int sms, bool use_atomic_add, bool use_fp32_reduce,
               bool is_zp_float) {
  int thread_m_blocks = div_ceil(moe_block_size, 16);
  bool m_block_size_8 = moe_block_size == 8;

  if (has_zp) {
    TORCH_CHECK(
        q_type == vllm::kU4 || q_type == vllm::kU8,
        "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
  } else {
    TORCH_CHECK(
        q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
        "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        q_type.str());
  }

  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int num_bits = q_type.size_bits();
  const int4* A_ptr = (const int4*)A;
  const int4* B_ptr = (const int4*)B;
  int4* C_ptr = (int4*)C;
  int4* C_tmp_ptr = (int4*)C_tmp;
  const int4* s_ptr = (const int4*)s;
  const int4* zp_ptr = (const int4*)zp;
  const int* g_idx_ptr = (const int*)g_idx;
  const int* perm_ptr = (const int*)perm;
  int4* a_tmp_ptr = (int4*)a_tmp;
  const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
  const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
  const int32_t* num_tokens_past_padded_ptr =
      (const int32_t*)num_tokens_past_padded;
  const float* topk_weights_ptr = (const float*)topk_weights;
  int* locks = (int*)workspace;

  if (has_act_order) {
    // Permute A columns
    auto kernel = permute_cols_kernel<8>;
    if (moe_block_size == 8) {
    } else if (moe_block_size == 16)
      kernel = permute_cols_kernel<16>;
    else if (moe_block_size == 32)
      kernel = permute_cols_kernel<32>;
    else if (moe_block_size == 48)
      kernel = permute_cols_kernel<48>;
    else if (moe_block_size == 64)
      kernel = permute_cols_kernel<64>;
    else
      TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);

    // avoid ">>>" being formatted to "> > >"
    // clang-format off
    kernel<<<sms, default_threads, 0, stream>>>(
        A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
        num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
    // clang-format on
    A_ptr = a_tmp_ptr;
    prob_m = prob_m * top_k;
    top_k = 1;

    // If we have a full K, then we can run the non-act-order version of Marlin
    // (since the weight rows are reordered by increasing group ids, and by
    // having a full K, we have full original groups)
    if (is_k_full) has_act_order = false;
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  // Set thread config
  exec_config_t exec_cfg;
  thread_config_t thread_tfg;
  if (thread_k != -1 && thread_n != -1) {
    thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
    exec_cfg = exec_config_t{1, thread_tfg};
    TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
                " is not divisible by thread_n = ", thread_n);
    TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
                " is not divisible by thread_k = ", thread_k);
  } else {
    // Auto config
    exec_cfg = determine_exec_config<scalar_t>(
        q_type, prob_m, prob_n, prob_k, thread_m_blocks, m_block_size_8,
        num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float,
        max_shared_mem);
    thread_tfg = exec_cfg.tb_cfg;
  }

  int num_threads = thread_tfg.num_threads;
  thread_k = thread_tfg.thread_k;
  thread_n = thread_tfg.thread_n;
  int blocks = sms * exec_cfg.blocks_per_sm;
  if (exec_cfg.blocks_per_sm > 1)
    max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n,
                              prob_k, num_bits, group_size, has_act_order,
                              is_k_full, has_zp, is_zp_float, max_shared_mem),
              "Invalid thread config: thread_m_blocks = ", thread_m_blocks,
              ", thread_k = ", thread_tfg.thread_k,
              ", thread_n = ", thread_tfg.thread_n,
              ", num_threads = ", thread_tfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float,
              ", max_shared_mem = ", max_shared_mem);

  auto kernel = get_marlin_kernel<scalar_t>(
      q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8,
      has_act_order, has_zp, group_blocks, num_threads, is_zp_float);

  if (kernel == MarlinDefault) {
    TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
                ", ", prob_k, "]", ", has_act_order = ", has_act_order,
                ", num_groups = ", num_groups, ", group_size = ", group_size,
                ", thread_m_blocks = ", thread_m_blocks,
                ", thread_n_blocks = ", thread_n_blocks,
                ", thread_k_blocks = ", thread_k_blocks,
                ", num_bits = ", num_bits);
  }

  cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       max_shared_mem);
  // avoid ">>>" being formatted to "> > >"
  // clang-format off
  kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
      A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
      sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
      topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
      prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
  // clang-format on
}
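
// Launch-shape note (sketch, hypothetical device numbers): the grid is sized
// as blocks = sms * exec_cfg.blocks_per_sm, e.g. 108 SMs * 2 = 216 blocks, and
// when blocks_per_sm > 1 the dynamic shared-memory budget is split per
// resident block (minus 1 KB of slack) before cudaFuncSetAttribute raises the
// kernel's MaxDynamicSharedMemorySize limit.
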
} // namespace MARLIN_NAMESPACE_NAME

torch::Tensor moe_wna16_marlin_gemm(
    torch::Tensor& a, std::optional<torch::Tensor> const& c_or_none,
    torch::Tensor& b_q_weight, torch::Tensor& b_scales,
    std::optional<torch::Tensor> const& b_zeros_or_none,
    std::optional<torch::Tensor> const& g_idx_or_none,
    std::optional<torch::Tensor> const& perm_or_none, torch::Tensor& workspace,
    torch::Tensor& sorted_token_ids, torch::Tensor& expert_ids,
    torch::Tensor& num_tokens_past_padded, torch::Tensor& topk_weights,
    int64_t moe_block_size, int64_t top_k, bool mul_topk_weights, bool is_ep,
    vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n,
    int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce,
    bool is_zp_float) {
  vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
  int pack_factor = 32 / b_q_type.size_bits();

  if (moe_block_size != 8) {
    TORCH_CHECK(moe_block_size % 16 == 0,
                "unsupported moe_block_size=", moe_block_size);
    TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64,
                "unsupported moe_block_size=", moe_block_size);
  }

  // Verify A
  TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
              ", size_m = ", size_m);
  TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
              ", size_k = ", size_k);

  // Verify B
  TORCH_CHECK(
      size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
      " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
  TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
              "Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
              ", size_k = ", size_k,
              ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
  TORCH_CHECK(
      b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
      "b_q_weight.size(2) = ", b_q_weight.size(2),
      " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
  int actual_size_n =
      (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
  TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
              ", actual_size_n = ", actual_size_n);

  // Verify device and strides
  TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");

  TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
  TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");

  TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
  TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel
  int sms = -1;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

  // Alloc buffers
  const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
  auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
  torch::Tensor c;
  if (c_or_none.has_value()) {
    c = c_or_none.value();
    TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
    TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
    TORCH_CHECK(c.size(0) == size_m * top_k,
                "Shape mismatch: c.size(0) = ", c.size(0),
                ", size_m * topk = ", size_m * top_k);
    TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1),
                ", size_n = ", size_n);
  } else {
    c = torch::empty({size_m * top_k, size_n}, options);
  }

  // Alloc C tmp buffer that is going to be used for the global reduce
  torch::Tensor c_tmp;
  auto options_fp32 =
      torch::TensorOptions().dtype(at::kFloat).device(a.device());
  if (use_fp32_reduce && !use_atomic_add) {
    // max num of threadblocks is sms * 4
    long max_c_tmp_size = min(
        (long)size_n * sorted_token_ids.size(0),
        (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
    if (moe_block_size == 8) max_c_tmp_size *= 2;
    c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
  } else {
    c_tmp = torch::empty({0}, options_fp32);
  }

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;

  int rank = b_scales.sizes().size();
  TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  torch::Tensor g_idx, perm, a_tmp;
  if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
    g_idx = g_idx_or_none.value();
    perm = perm_or_none.value();

    TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
    TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
    TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
    TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");

    // Verify g_idx and perm
    TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) ||
                    (g_idx.size(-1) == size_k && perm.size(-1) == size_k),
                "Unexpected g_idx.size(-1) = ", g_idx.size(-1),
                " and perm.size(-1) = ", perm.size(-1),
                ", where size_k = ", size_k);
  } else {
    g_idx = torch::empty({0}, options);
    perm = torch::empty({0}, options);
    a_tmp = torch::empty({0}, options);
  }
  bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;

  if (has_act_order) {
    a_tmp = torch::empty({size_m * top_k, size_k}, options);
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }

  } else {
    a_tmp = torch::empty({0}, options);
    if (num_groups > 1) {
      TORCH_CHECK(
          size_k % num_groups == 0, "size_k = ", size_k,
          ", is not divisible by b_scales.size(1) = ", b_scales.size(1));
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  torch::Tensor b_zeros;
  if (b_zeros_or_none.has_value()) {
    b_zeros = b_zeros_or_none.value();
    TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
    TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
  } else {
    b_zeros = torch::empty({0}, options);
  }
  bool has_zp = b_zeros.size(-1) > 0;

  if (has_zp) {
    TORCH_CHECK(
        b_q_type == vllm::kU4,
        "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
  } else {
    TORCH_CHECK(
        b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128,
        "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
        b_q_type.str());
  }

  if (has_zp && is_zp_float) {
    TORCH_CHECK(a.scalar_type() == at::ScalarType::Half,
                "Computation type must be float16 (half) when using float zero "
                "points.");
  }

  // Verify b_zeros
  if (has_zp) {
    int rank = b_zeros.sizes().size();
    TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
    if (is_zp_float) {
      TORCH_CHECK(b_zeros.size(2) == size_n,
                  "b_zeros dim 2 = ", b_zeros.size(2),
                  " is not size_n = ", size_n);
      TORCH_CHECK(num_groups == b_zeros.size(1),
                  "b_zeros dim 1 = ", b_zeros.size(1),
                  " is not num_groups = ", num_groups);
      TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
    } else {
      TORCH_CHECK(b_zeros.size(1) == num_groups,
                  "b_zeros dim 1 = ", b_zeros.size(1),
                  " is not num_groups = ", num_groups);
      TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
                  "b_zeros dim 2 = ", b_zeros.size(2),
                  " is not size_n / pack_factor = ", size_n / pack_factor);
    }
  }

  // Verify workspace size
  TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
              "size_n = ", size_n, ", is not divisible by min_thread_n = ",
              MARLIN_NAMESPACE_NAME::min_thread_n);

  int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
  int min_workspace_size = min(
      max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
  TORCH_CHECK(workspace.numel() >= min_workspace_size,
              "workspace.numel = ", workspace.numel(),
              " is below min_workspace_size = ", min_workspace_size);

  int dev = a.get_device();
  if (a.scalar_type() == at::ScalarType::Half) {
    MARLIN_NAMESPACE_NAME::marlin_mm<half>(
        a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
        c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
        b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
        a_tmp.data_ptr<at::Half>(), sorted_token_ids.data_ptr(),
        expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(),
        topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep,
        size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order,
        is_k_full, has_zp, num_groups, group_size, dev,
        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
        use_atomic_add, use_fp32_reduce, is_zp_float);
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
    MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
        a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
        c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
        b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
        perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
        sorted_token_ids.data_ptr(), expert_ids.data_ptr(),
        num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(),
        moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k,
        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
        thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float);
  } else {
    TORCH_CHECK(false,
                "moe_wna16_marlin_gemm only supports bfloat16 and float16");
  }

  return c;
}

#endif

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("moe_wna16_marlin_gemm", &moe_wna16_marlin_gemm);
}
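
// Shape summary implied by the checks in moe_wna16_marlin_gemm() above
// (descriptive only): a is [size_m, size_k]; b_q_weight has
// size(1) == size_k / tile_size and size(2) == size_n * tile_size /
// pack_factor; b_scales is 3-D with size(1) == num_groups and
// size(2) == size_n; the returned c is [size_m * top_k, size_n].
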
@@ -13,7 +13,6 @@
 template <typename scalar_t, int bit, int GROUPS>
 __global__ void moe_wna16_gemm_kernel(
     const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
-
     const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
     const uint32_t* __restrict__ qzeros,
 
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
     if (token_index / top_k >= size_m) break;
 
     num_valid_tokens = m + 1;
-    if (blockIdx.z == 0 && offset_n < size_n)
-      output[token_index * size_n + offset_n] = Dtype::int2num(0);
 
     if (expert_id != -1) {
       int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto options =
-      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  output.zero_();
 
   const int num_experts = b_qweight.size(0);
   const int size_m = input.size(0);
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
   const uint32_t* b_qzeros_ptr;
   if (b_qzeros.has_value())
     b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
-  const float* topk_weights_ptr;
+  const float* topk_weights_ptr = nullptr;
   if (topk_weights.has_value())
-    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+    topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
 
   int groups_per_block_row = BLOCK_SIZE_K / group_size;
   TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");

@@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.impl("moe_wna16_gemm", torch::kCUDA, &moe_wna16_gemm);
 
   m.def(
-      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
-      "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
-      "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
-      "int b_q_type, SymInt size_m, "
-      "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
-      "topk, "
-      "int moe_block_size, bool replicate_input, bool apply_weights)"
-      " -> Tensor");
+      "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
+      "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
+      "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
+      "Tensor sorted_token_ids,"
+      "Tensor! expert_ids, Tensor! num_tokens_past_padded,"
+      "Tensor! topk_weights, int moe_block_size, int top_k, "
+      "bool mul_topk_weights, bool is_ep, int b_q_type_id,"
+      "int size_m, int size_n, int size_k,"
+      "bool is_full_k, bool use_atomic_add,"
+      "bool use_fp32_reduce, bool is_zp_float) -> Tensor");
 
   // conditionally compiled so impl registration is in source file
 
 #endif

csrc/ops.h (37 changed lines)

@@ -52,6 +52,15 @@ void paged_attention_v2(
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);
 
+#ifndef USE_ROCM
+void merge_attn_states(torch::Tensor& output,
+                       std::optional<torch::Tensor> output_lse,
+                       const torch::Tensor& prefix_output,
+                       const torch::Tensor& prefix_lse,
+                       const torch::Tensor& suffix_output,
+                       const torch::Tensor& suffix_lse);
+#endif
+
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               double epsilon);
 
@@ -119,6 +128,8 @@ void advance_step_flashinfer(
     torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
     torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
+torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);
+
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                         const torch::Tensor& codebooks,
@@ -143,7 +154,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
 #endif
 
 torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
-                              int64_t n);
+                              int64_t n,
+                              std::optional<at::ScalarType> const& dtype);
 
 torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
                                   int64_t type, int64_t row);
@@ -164,6 +176,7 @@ int64_t ggml_moe_get_block_size(int64_t type);
 bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
+bool cutlass_group_gemm_supported(int64_t cuda_device_capability);
 
 void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
                            torch::Tensor const& B, torch::Tensor const& A_sf,
@@ -175,6 +188,19 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b_scales,
                        std::optional<torch::Tensor> const& bias);
 
+void cutlass_moe_mm(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+
+void get_cutlass_moe_mm_data(
+    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
+    const int64_t num_experts, const int64_t n, const int64_t k);
+
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
@@ -251,10 +277,10 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                        const std::optional<at::Tensor>& has_initial_state,
                        bool silu_activation, int64_t pad_slot_id);
 
-#ifndef USE_ROCM
 using fptr_t = int64_t;
 fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
-                      torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
+                      torch::Tensor& rank_data, int64_t rank,
+                      bool fully_connected);
 void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 fptr_t reg_buffer, int64_t reg_buffer_sz_bytes);
 void dispose(fptr_t _fa);
@@ -265,4 +291,7 @@ get_graph_buffer_ipc_meta(fptr_t _fa);
 void register_graph_buffers(fptr_t _fa,
                             const std::vector<std::vector<int64_t>>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
-#endif
+std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
+    int64_t size);
+int64_t open_mem_handle(torch::Tensor& mem_handle);
+void free_shared_buffer(int64_t buffer);

csrc/quantization/cutlass_w8a8/moe/get_group_starts.cuh (new file, 80 lines)

#pragma once

#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>

#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"

template <typename ElementAB, typename ElementC, typename ElementAccumulator>
__global__ void get_group_gemm_starts(
    int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
    ElementC** out_offsets, ElementAccumulator** a_scales_offsets,
    ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int,
    ElementAB* b_base_as_int, ElementC* out_base_as_int,
    ElementAccumulator* a_scales_base_as_int,
    ElementAccumulator* b_scales_base_as_int, int64_t n, int64_t k,
    bool per_act_token, bool per_out_ch) {
  int expert_id = threadIdx.x;

  int64_t expert_offset = expert_offsets[expert_id];

  a_offsets[expert_id] = a_base_as_int + expert_offset * k;
  b_offsets[expert_id] = b_base_as_int + expert_id * k * n;
  out_offsets[expert_id] = out_base_as_int + expert_offset * n;
  a_scales_offsets[expert_id] =
      a_scales_base_as_int + (per_act_token ? expert_offset : 0);
  b_scales_offsets[expert_id] =
      b_scales_base_as_int + (per_out_ch ? n * expert_id : expert_id);
}

#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE)                    \
  else if (out_tensors.dtype() == TENSOR_C_TYPE) {                         \
    get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>            \
        <<<1, num_experts, 0, stream>>>(                                   \
            static_cast<int32_t*>(expert_offsets.data_ptr()),              \
            static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),       \
            static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),       \
            static_cast<C_TYPE**>(out_ptrs.data_ptr()),                    \
            static_cast<float**>(a_scales_ptrs.data_ptr()),                \
            static_cast<float**>(b_scales_ptrs.data_ptr()),                \
            static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()),     \
            static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()),     \
            static_cast<C_TYPE*>(out_tensors.data_ptr()),                  \
            static_cast<float*>(a_scales.data_ptr()),                      \
            static_cast<float*>(b_scales.data_ptr()), out_tensors.size(1), \
            a_tensors.size(1), per_act_token, per_out_ch);                 \
  }

namespace {

void run_get_group_gemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
    torch::Tensor& out_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int num_experts = static_cast<int>(expert_offsets.size(0));
  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, half)
  else {
    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
  }
}

}  // namespace
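
// Layout sketch for the pointer arrays built above (illustrative, assuming
// expert_offsets holds prefix sums of tokens per expert): with
// expert_offsets = {0, 3, 7} and k = 4096, thread e sets
//   a_offsets[e]   = a_base   + expert_offsets[e] * k
//   b_offsets[e]   = b_base   + e * k * n      (one weight matrix per expert)
//   out_offsets[e] = out_base + expert_offsets[e] * n
// so expert 1 reads activation rows [3, 7) and writes output rows [3, 7).
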
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu (new file, 160 lines)

#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (16, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M16 {
  // M in [1, 16]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_K8192 {
  // K in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_N8192 {
  // N in [8192, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule =
      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
  using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
  using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;

  using Cutlass3xGemm =
      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                            KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
              "A tensors must be of type float8_e4m3fn.");
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
              "B tensors must be of type float8_e4m3fn.");

  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
  using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

  uint32_t const m = a_tensors.size(0);
  uint32_t const n = out_tensors.size(1);
  uint32_t const k = a_tensors.size(1);

  if (n >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else if (k >= 8192) {
    cutlass_group_gemm_caller<Cutlass3xGemmK8192>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else if (m <= 16) {
    cutlass_group_gemm_caller<Cutlass3xGemmM16>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else {
    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  }
}

void dispatch_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  if (out_tensors.dtype() == torch::kBFloat16) {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  } else {
    run_cutlass_moe_mm_sm90<cutlass::float_e4m3_t, cutlass::half_t>(
        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
        problem_sizes, a_strides, b_strides, c_strides);
  }
}

}  // namespace

void cutlass_moe_mm_sm90(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                       expert_offsets, problem_sizes, a_strides, b_strides,
                       c_strides);
}

csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh (new file, 149 lines)
@@ -0,0 +1,149 @@
#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "get_group_starts.cuh"

using namespace cute;

namespace {

using ProblemShape =
    cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

using ElementAccumulator = float;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;

template <typename ElementAB_, typename ElementC_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_group_gemm {
  using ElementAB = ElementAB_;
  using ElementC = void;
  using ElementD = ElementC_;
  using ElementAccumulator = float;

  using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;

  using StrideC =
      cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;

  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
          LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
          LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
          Stages, KernelSchedule>::CollectiveOp;

  using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
      ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_group_gemm_caller(
    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int num_experts = static_cast<int>(expert_offsets.size(0));
  int k_size = a_tensors.size(1);
  int n_size = out_tensors.size(1);

  bool per_act_token = a_scales.numel() != 1;
  bool per_out_ch = b_scales.numel() != num_experts;

  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  auto options_int =
      torch::TensorOptions().dtype(torch::kInt64).device(a_tensors.device());

  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);

  run_get_group_gemm_starts(expert_offsets, a_ptrs, b_ptrs, out_ptrs,
                            a_scales_ptrs, b_scales_ptrs, a_tensors, b_tensors,
                            out_tensors, a_scales, b_scales);

  using GemmKernel = typename Gemm::GemmKernel;
  using StrideA = Stride<int64_t, Int<1>, Int<0>>;
  using StrideB = Stride<int64_t, Int<1>, Int<0>>;
  using StrideC = typename GemmKernel::InternalStrideC;

  ProblemShape::UnderlyingProblemShape* problem_sizes_as_shapes =
      static_cast<ProblemShape::UnderlyingProblemShape*>(
          problem_sizes.data_ptr());
  ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};

  typename GemmKernel::MainloopArguments mainloop_args{
      static_cast<const ElementAB**>(a_ptrs.data_ptr()),
      static_cast<StrideA*>(a_strides.data_ptr()),
      static_cast<const ElementAB**>(b_ptrs.data_ptr()),
      static_cast<StrideB*>(b_strides.data_ptr())};

  // Currently, we are only able to do broadcast on either all or none a_scales
  // and on either all or none b_scales
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
          static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
          per_act_token, per_out_ch),
      nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
      static_cast<ElementD**>(out_ptrs.data_ptr()),
      static_cast<StrideC*>(c_strides.data_ptr())};

  typename GemmKernel::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
      epilogue_args};

  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a_tensors.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace
csrc/quantization/cutlass_w8a8/moe/moe_data.cu (new file, 90 lines)
@@ -0,0 +1,90 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(const int* __restrict__ topk_ids,
                                      int32_t* problem_sizes1,
                                      int32_t* problem_sizes2,
                                      int32_t* atomic_buffer,
                                      const int topk_length, const int n,
                                      const int k) {
  int expert_id = blockIdx.x;

  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = 2 * n;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = k;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}

__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
    int32_t* atomic_buffer, const int num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}

__global__ void compute_arg_sorts(const int* __restrict__ topk_ids,
                                  int32_t* input_permutation,
                                  int32_t* output_permutation,
                                  int32_t* atomic_buffer, const int topk_length,
                                  const int topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    if (topk_ids[i] == expert_id) {
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}

void get_cutlass_moe_mm_data_caller(
    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
    const int64_t num_experts, const int64_t n, const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 =
      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
  compute_expert_offsets<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
  compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(),
      topk_ids.size(1));
}
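To make the bookkeeping above concrete, here is a small host-side mirror of compute_problem_sizes and compute_expert_offsets for a toy routing. It is an illustration written for this document, not code from the diff, and it runs entirely on the CPU:

// Illustrative CPU mirror of the kernels above for a toy case:
// 3 tokens, topk = 2, num_experts = 3, n = 4, k = 8.
// topk_ids (flattened): {0, 1, 1, 2, 0, 1}
//   occurrences: expert 0 -> 2, expert 1 -> 3, expert 2 -> 1
//   expert_offsets: {0, 2, 5, 6}
//   problem_sizes1[e] = {occ, 2 * n, k}, problem_sizes2[e] = {occ, k, n}
#include <cstdio>
#include <vector>

int main() {
  const int num_experts = 3, n = 4, k = 8;
  std::vector<int> topk_ids = {0, 1, 1, 2, 0, 1};

  std::vector<int> occ(num_experts, 0);
  for (int id : topk_ids) occ[id]++;

  std::vector<int> expert_offsets(num_experts + 1, 0);
  for (int e = 0; e < num_experts; ++e)
    expert_offsets[e + 1] = expert_offsets[e] + occ[e];

  for (int e = 0; e < num_experts; ++e)
    std::printf("expert %d: sizes1=(%d,%d,%d) sizes2=(%d,%d,%d) offset=%d\n",
                e, occ[e], 2 * n, k, occ[e], k, n, expert_offsets[e]);
  std::printf("total routed rows: %d\n", expert_offsets[num_experts]);  // 6
  return 0;
}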
@@ -29,6 +29,20 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
+
+void cutlass_moe_mm_sm90(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+
+void get_cutlass_moe_mm_data_caller(
+    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
+    const int64_t num_experts, const int64_t n, const int64_t k);
+
 #endif
 
 #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
@@ -102,6 +116,19 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
   return false;
 }
 
+bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
+  // CUTLASS grouped FP8 kernels need at least CUDA 12.3
+  // and SM90 (Hopper)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability == 90) {
+    return CUDA_VERSION >= 12030;
+  }
+#endif
+
+  return false;
+}
 
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
@@ -168,6 +195,46 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                           version_num);
 }
 
+void cutlass_moe_mm(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides) {
+  int32_t version_num = get_sm_version_num();
+#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
+  cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                      expert_offsets, problem_sizes, a_strides, b_strides,
+                      c_strides);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
+      ". Required capability: 90");
+}
+
+void get_cutlass_moe_mm_data(
+    const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
+    torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
+    torch::Tensor& input_permutation, torch::Tensor& output_permutation,
+    const int64_t num_experts, const int64_t n, const int64_t k) {
+  // This function currently gets compiled only if we have a valid cutlass moe
+  // mm to run it for.
+  int32_t version_num = get_sm_version_num();
+#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
+  get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
+                                 problem_sizes2, input_permutation,
+                                 output_permutation, num_experts, n, k);
+  return;
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      false,
+      "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
+      "CUDA device capability: ",
+      version_num, ". Required capability: 90");
+}
 
 void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
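Taken together, the two new entry points are meant to be used back to back: first derive the per-expert bookkeeping from topk_ids, then run the grouped GEMM. The sketch below only wires up the call order; the tensor shapes are inferred from the kernels shown earlier, and the activation/weight/scale/stride setup is deliberately elided, so this is an illustration rather than a complete example:

// Hedged sketch of the intended call order (shapes inferred from the kernels
// above; strides/scales construction omitted).
void moe_mm_example(torch::Tensor const& topk_ids,  // [num_tokens, topk], int32
                    int64_t num_experts, int64_t n, int64_t k) {
  auto opts =
      torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  auto expert_offsets = torch::empty({num_experts + 1}, opts);
  auto problem_sizes1 = torch::empty({num_experts, 3}, opts);
  auto problem_sizes2 = torch::empty({num_experts, 3}, opts);
  auto input_perm = torch::empty({topk_ids.numel()}, opts);
  auto output_perm = torch::empty({topk_ids.numel()}, opts);

  // 1) Fill per-expert problem sizes, offsets and token permutations.
  get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1,
                          problem_sizes2, input_perm, output_perm, num_experts,
                          n, k);

  // 2) Gather activations with input_perm, then run the grouped GEMM:
  //    cutlass_moe_mm(out, a, b, a_scales, b_scales, expert_offsets,
  //                   problem_sizes1, a_strides, b_strides, c_strides);
  //    (construction of a/b/scales/strides is not shown here).
}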
@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor =
-      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
-
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
 
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
-                      min_scaling_factor);
+    token_scale = max(token_scale / quant_type_max_v<fp8_type>,
+                      min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
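In equation form, the per-token scale computed by this kernel is unchanged by the refactor; only the constants move behind named helpers (a hedged reading, since the new utility header is not shown in these hunks):

\[
s_{\text{token}} \;=\; \max\!\left(\frac{\max_i |x_i|}{q_{\max}},\; s_{\min}\right),
\qquad
q_{\max} = \texttt{quant\_type\_max\_v<fp8\_type>},
\qquad
s_{\min} = \texttt{min\_scaling\_factor<fp8\_type>::val()}
\]

where the removed local constant corresponded to \( s_{\min} = 1 / (q_{\max} \cdot 512) \).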
@@ -1,20 +1,12 @@
 #pragma once
 
 #include "quantization/vectorization.cuh"
+#include "quantization/utils.cuh"
 
 #include <cmath>
-#include <c10/core/ScalarType.h>
 
-#ifndef USE_ROCM
-  #include <c10/util/Float8_e4m3fn.h>
-  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
-#else
-  #include <ATen/hip/HIPContext.h>
-  #include <c10/util/Float8_e4m3fn.h>
-  #include <c10/util/Float8_e4m3fnuz.h>
+#ifdef USE_ROCM
   #include "amd/quant_utils.cuh"
-  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
-  #define MAYBE_HOST_DEVICE
 #endif
 
 // Determines the preferred FP8 type for the current platform.
@@ -31,29 +23,6 @@ static bool is_fp8_ocp() {
 #endif
 }
 
-template <typename T>
-struct fp8_e4m3_adjusted_max;
-
-template <>
-struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
-  static constexpr c10::Float8_e4m3fn val() {
-    return std::numeric_limits<c10::Float8_e4m3fn>::max();
-  }
-};
-
-// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
-// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
-template <>
-struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
-  static constexpr c10::Float8_e4m3fnuz val() {
-    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
-  }
-};
-
-template <typename T>
-MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
-    fp8_e4m3_adjusted_max<T>::val();
-
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -76,8 +45,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
     x = val / scale;
   }
 
-  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
-                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  float r =
+      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
   return static_cast<fp8_type>(r);
 #else
@@ -123,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
   }
 }
 
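The call sites above only compile if quantization/utils.cuh provides quant_type_max_v and min_scaling_factor. That header is not part of the hunks shown here, so the following is a minimal sketch of traits with the required shape, not the actual vLLM definitions; the int8 specialization mirrors the epsilon fallback removed from the host-side dispatch code.

#include <cstdint>
#include <limits>

// Sketch only: largest representable value of the quantized type.
template <typename T>
struct quant_type_max {
  static constexpr T val() { return std::numeric_limits<T>::max(); }
};

template <typename T>
static constexpr T quant_type_max_v = quant_type_max<T>::val();

// Sketch only: lower clamp for dynamically computed scales.
template <typename T>
struct min_scaling_factor {
  __host__ __device__ static float val() {
    return 1.0f / (static_cast<float>(quant_type_max_v<T>) * 512.0f);
  }
};

// The removed dispatch code used FLT_EPSILON for int8 rather than the
// fp8-style 1/(max*512); a specialization keeps that behaviour.
template <>
struct min_scaling_factor<int8_t> {
  __host__ __device__ static float val() {
    return std::numeric_limits<float>::epsilon();
  }
};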
@@ -14,8 +14,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
     float* __restrict__ scales,           // [num_tokens]
     scalar_t const* __restrict__ input,   // [..., hidden_size]
     scalar_t const* __restrict__ weight,  // [hidden_size]
-    float const* scale_ub, float const var_epsilon,
-    float const min_scaling_factor, int32_t const hidden_size,
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
     scalar_t* __restrict__ residual = nullptr) {
   float rms = 0.0f;
   float token_scale = 0.0f;
@@ -27,8 +26,8 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
   // Compute scale
   vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
                                                      has_residual>(
-      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
-      hidden_size, residual);
+      &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
+      residual);
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
@@ -50,8 +49,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
     float* __restrict__ scales,           // [num_tokens]
     scalar_t const* __restrict__ input,   // [..., hidden_size]
     scalar_t const* __restrict__ weight,  // [hidden_size]
-    float const* scale_ub, float const var_epsilon,
-    float const min_scaling_factor, int32_t const hidden_size,
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
     scalar_t* __restrict__ residual = nullptr) {
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -60,8 +58,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
   if (can_vectorize) {
     return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
                                                 has_residual>(
-        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
-        hidden_size, residual);
+        out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
+        residual);
   }
 
   float rms = 0.0f;
@@ -72,8 +70,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
                                    var_epsilon, residual);
   // Compute Scale
   vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
-      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
-      hidden_size, residual);
+      &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
+      residual);
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
@@ -105,11 +103,6 @@ void rms_norm_dynamic_per_token_quant_dispatch(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  const float min_scaling_factor =
-      out.dtype() == torch::kInt8
-          ? std::numeric_limits<float>::epsilon()
-          : 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);
-
   if (residual.has_value()) {
     VLLM_DISPATCH_QUANT_TYPES(
         out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
@@ -119,8 +112,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
               out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
               input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
               scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-              var_epsilon, min_scaling_factor, hidden_size,
-              residual->data_ptr<scalar_in_t>());
+              var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
         });
 
   } else {
@@ -132,7 +124,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
               out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
              input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
              scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-              var_epsilon, min_scaling_factor, hidden_size, nullptr);
+              var_epsilon, hidden_size, nullptr);
         });
   }
 }
@@ -5,6 +5,7 @@
  */
 
 #include "quantization/vectorization.cuh"
+#include "quantization/utils.cuh"
 #include "quant_conversions.cuh"
 
 #ifndef USE_ROCM
@@ -24,7 +25,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
   // sum of squares
   float ss = 0.0f;
 
-  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
     if constexpr (has_residual) {
       x += static_cast<float>(residual[token_offset + i]);
@@ -51,14 +52,14 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    float const min_scaling_factor, int32_t const hidden_size,
+    int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;
-  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
+  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
   float block_absmax_val_maybe = 0.0f;
-  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
     if constexpr (has_residual) {
       x += static_cast<float>(residual[token_offset + i]);
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
     scale = block_absmax_val_maybe;
   }
   // token scale computation
-  scale = max(scale / qmax, min_scaling_factor);
+  scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
   s_token_scale = scale;                 // Shared memory store
   all_token_scales[blockIdx.x] = scale;  // Global output store
 }
@@ -103,7 +104,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;
 
-  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
     float x = static_cast<float>(input[token_offset + i]);
     if constexpr (has_residual) {
       x += static_cast<float>(residual[token_offset + i]);
@@ -142,7 +143,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
   int32_t const num_vec_elems = hidden_size >> 2;
 
 #pragma unroll 4
-  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> in = vec_input[i];
 
     vec4_t<float> x;
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    float const min_scaling_factor, int32_t const hidden_size,
+    int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;
@@ -200,13 +201,13 @@ __device__ void compute_dynamic_per_token_scales(
         reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
   }
 
-  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
+  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
   int32_t const num_vec_elems = hidden_size >> 2;
   float block_absmax_val_maybe = 0.0f;
 
 #pragma unroll 4
-  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> in = vec_input[i];
     vec4_t<scalar_t> const w = vec_weight[i];
 
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
     scale = block_absmax_val_maybe;
   }
   // token scale computation
-  scale = max(scale / qmax, min_scaling_factor);
+  scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
   s_token_scale = scale;                 // shared memory store
   all_token_scales[blockIdx.x] = scale;  // global output store
 }
@@ -286,7 +287,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
   // TODO(luka/varun) extract into type-agnostic vectorized quant function to
   //  replace scaled_fp8_conversion_vec
 #pragma unroll 4
-  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
+  for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> const in = vec_input[i];
     vec4_t<scalar_t> const w = vec_weight[i];
 
@@ -33,8 +33,8 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
 
 template <typename fp8_type>
 static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
-  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
-                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  float const r =
+      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
   return static_cast<fp8_type>(r);
 }
 
@@ -94,17 +94,17 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     dfloat2 v;
     dequantize_kernel(vx, ib, iqs, v);
 
-    y[iybs + iqs + 0]        = v.x;
-    y[iybs + iqs + y_offset] = v.y;
+    y[iybs + iqs + 0]        = convert_from_half<dst_t>(v.x);
+    y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i   = blockIdx.x;
+    const auto i  = blockIdx.x;
     const block_q2_K * x = (const block_q2_K *) vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
@@ -114,19 +114,19 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
 
     half dall = __low2half(x[i].dm);
     half dmin = __high2half(x[i].dm);
-    y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4)));
-    y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4)));
-    y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4)));
-    y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4)));
+    y[l+ 0] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))));
+    y[l+32] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))));
+    y[l+64] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))));
+    y[l+96] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))));
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_q3_K * x = (const block_q3_K *) vx;
 
-    const int r = threadIdx.x/4;
+    const auto r = threadIdx.x/4;
     const int tid = r/2;
     const int is0 = r%2;
     const int l0 = 16*is0 + 4*(threadIdx.x%4);
@@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
     const uint8_t * q = x[i].qs + 32*n;
     const uint8_t * hm = x[i].hmask;
 
-    for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)));
+    for (int l = l0; l < l0+4; ++l) {
+        y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
+    }
 }
 
 static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
@@ -164,10 +166,10 @@ template<typename dst_t>
 static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q4_K * x = (const block_q4_K *) vx;
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
 
     // assume 32 threads
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8;
     const int ir = tid%8;
     const int is = 2*il;
@@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
     const half d2 = __hmul(dall, __int2half_rn(sc));
     const half m2 = __hmul(dmin, __int2half_rn(m));
     for (int l = 0; l < n; ++l) {
-        y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1);
-        y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2);
+        y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
+        y[l +32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
     }
 }
 
@@ -197,10 +199,10 @@ template<typename dst_t>
 static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q5_K * x = (const block_q5_K *) vx;
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
 
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/16;  // il is in 0...3
     const int ir = tid%16;  // ir is in 0...15
     const int is = 2*il;    // is is in 0...6
@@ -220,21 +222,21 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
     const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
 
     uint8_t hm = 1 << (2*il);
-    y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1);
-    y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1);
+    y[ 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
+    y[ 1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
     hm <<= 1;
-    y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2);
-    y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2);
+    y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
+    y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const block_q6_K * x = (const block_q6_K *) vx;
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
 
     // assume 64 threads - this is very slightly better than the one below
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int ip = tid/32;   // ip is 0 or 1
     const int il = tid - 32*ip; // 0...32
     const int is = 8*ip + il/16;
@@ -247,19 +249,19 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
     const uint8_t qh = x[i].qh[32*ip + il];
     const int8_t * sc = x[i].scales + is;
 
-    y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
-    y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
-    y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
-    y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
+    y[ 0] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
+    y[32] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
+    y[64] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
+    y[96] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -269,16 +271,16 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
     const uint32_t aux32 = q2[2] | (q2[3] << 16);
     const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq2_xs * x = (const block_iq2_xs *) vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -286,33 +288,33 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
     const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
     const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
     const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq2_s * x = (const block_iq2_s *) vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
     const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
     const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
     const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
-    for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f));
+    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -324,18 +326,18 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
     const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
     const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
     for (int j = 0; j < 4; ++j) {
-        y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
-        y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq3_s * x = (const block_iq3_s *) vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 8*il;
@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
     const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
     const uint8_t signs = x[i].signs[4*ib + il];
     for (int j = 0; j < 4; ++j) {
-        y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f));
-        y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f));
+        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
     }
 }
 
@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
     grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
     grid32[0] &= 0x0f0f0f0f;
     for (int j = 0; j < 8; ++j) {
-        y[j] = __float2half(d * (q[j] + delta));
+        y[j] = d * (q[j] + delta);
     }
 }
 
@@ -392,43 +394,43 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
     grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
     grid32[0] &= 0x0f0f0f0f;
     for (int j = 0; j < 8; ++j) {
-        y[j] = __float2half(d * (q[j] + delta));
+        y[j] = d * (q[j] + delta);
     }
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
 
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
     const uint8_t * q4 = x[ib].qs + 4*il;
     const float d = __half2float(x[ib].d);
     for (int j = 0; j < 4; ++j) {
-        y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
-        y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
     }
 
 }
 
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
-    const int i = blockIdx.x;
+    const auto i = blockIdx.x;
     const block_iq4_xs * x = (const block_iq4_xs *)vx;
 
-    const int tid = threadIdx.x;
+    const auto tid = threadIdx.x;
     const int il = tid/8; // 0...3
     const int ib = tid%8; // 0...7
     dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
     const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
     for (int j = 0; j < 4; ++j) {
-        y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]);
-        y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]);
+        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+        y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
     }
 }
 
@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, 
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
-static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
+template<typename dst_t>
+static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
     switch (type) {
         case 2:
             return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
@@ -565,4 +568,4 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) {
         default:
             return nullptr;
     }
 }
@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
- typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream);
+ template<typename dst_t>
+ using to_cuda_ggml_t = void (*)(const void * __restrict__ x, dst_t * __restrict__ y, int k, cudaStream_t stream);
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
typedef void (*load_tiles_cuda_t)(
@@ -1075,6 +1076,25 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(

// Utility function

+ template<typename dst_t>
+ static __device__ __forceinline__ dst_t convert_from_half(half val) {
+ return val;
+ }
+
+ template<>
+ __device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ return __float2bfloat16(__half2float(val));
+ #else
+ return __half2float(val);
+ #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ }
+
+ template<>
+ __device__ __forceinline__ float convert_from_half<float>(half val) {
+ return __half2float(val);
+ }
+
#if defined(USE_ROCM)

#ifndef __has_builtin
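The convert_from_half helpers above give each supported output type an explicit half-to-dst_t store path, with the bfloat16 specialization falling back to a float round-trip below SM80. A minimal usage sketch follows, assuming the helper is in scope as declared in the hunk; the copy_half_to kernel, its launch shape, and the src/dst/n/stream names are hypothetical.

#include <cuda_fp16.h>
#include <c10/util/BFloat16.h>

// Templated copy kernel: reads half, stores to whatever dst_t the caller picked.
template <typename dst_t>
__global__ void copy_half_to(const half* __restrict__ in, dst_t* __restrict__ out, int n) {
  auto i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < (unsigned)n) {
    out[i] = convert_from_half<dst_t>(in[i]);  // helper declared in the hunk above
  }
}

// e.g. copy_half_to<c10::BFloat16><<<(n + 255) / 256, 256, 0, stream>>>(src, dst, n);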
@@ -19,11 +19,11 @@ template <typename scalar_t>
static __global__ void quantize_q8_1(const scalar_t* __restrict__ x,
void* __restrict__ vy, const int kx,
const int kx_padded) {
- const int ix = blockDim.x * blockIdx.x + threadIdx.x;
+ const auto ix = blockDim.x * blockIdx.x + threadIdx.x;
if (ix >= kx_padded) {
return;
}
- const int iy = blockDim.y * blockIdx.y + threadIdx.y;
+ const auto iy = blockDim.y * blockIdx.y + threadIdx.y;
const int i_padded = iy * kx_padded + ix;

block_q8_1* y = (block_q8_1*)vy;
@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
}

torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
- int64_t type, int64_t m, int64_t n) {
+ int64_t type, int64_t m, int64_t n,
+ std::optional<at::ScalarType> const& dtype) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
- auto options =
-     torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
+ auto dtype_ = dtype.value_or(torch::kFloat16);
+ auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
- const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
- to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);
+ VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
+   auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
+   to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
+ });

return DW;
}

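The widened ggml_dequantize entry point above keeps fp16 as the default output dtype and dispatches the templated dequantizer on the output tensor's scalar type. The sketch below reproduces that shape with the stock ATen dispatch macro rather than vLLM's VLLM_DISPATCH_FLOATING_TYPES wrapper; dequantize_with_dtype and fill_dequantized are hypothetical names, and the memset merely stands in for a real per-type dequantize launch.

#include <optional>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

// Placeholder backend: a real implementation would launch a typed dequantize kernel here.
template <typename scalar_t>
void fill_dequantized(scalar_t* out, int64_t count, cudaStream_t stream) {
  cudaMemsetAsync(out, 0, count * sizeof(scalar_t), stream);
}

torch::Tensor dequantize_with_dtype(torch::Tensor W, int64_t m, int64_t n,
                                    std::optional<at::ScalarType> const& dtype) {
  auto dtype_ = dtype.value_or(torch::kFloat16);  // fp16 stays the default
  auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
  at::Tensor DW = torch::empty({m, n}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kHalf, at::kBFloat16, DW.scalar_type(), "dequantize_with_dtype", [&] {
        fill_dequantized<scalar_t>((scalar_t*)DW.data_ptr(), m * n, stream);
      });
  return DW;
}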
@@ -375,25 +380,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
int64_t ggml_moe_get_block_size(int64_t type) {
switch (type) {
case 2:
- return MMQ_X_Q4_0;
+ return MOE_X_Q4_0;
case 3:
- return MMQ_X_Q4_1;
+ return MOE_X_Q4_1;
case 6:
- return MMQ_X_Q5_0;
+ return MOE_X_Q5_0;
case 7:
- return MMQ_X_Q5_1;
+ return MOE_X_Q5_1;
case 8:
- return MMQ_X_Q8_0;
+ return MOE_X_Q8_0;
case 10:
- return MMQ_X_Q2_K;
+ return MOE_X_Q2_K;
case 11:
- return MMQ_X_Q3_K;
+ return MOE_X_Q3_K;
case 12:
- return MMQ_X_Q4_K;
+ return MOE_X_Q4_K;
case 13:
- return MMQ_X_Q5_K;
+ return MOE_X_Q5_K;
case 14:
- return MMQ_X_Q6_K;
+ return MOE_X_Q6_K;
}
return 0;
}
@@ -14,10 +14,10 @@ static __device__ __forceinline__ void mul_mat_q(

const int & ncols_dst = ncols_y;

- const int row_dst_0 = blockIdx.x*mmq_y;
+ const auto row_dst_0 = blockIdx.x*mmq_y;
const int & row_x_0 = row_dst_0;

- const int col_dst_0 = blockIdx.y*mmq_x;
+ const auto col_dst_0 = blockIdx.y*mmq_x;
const int & col_y_0 = col_dst_0;

int * tile_x_ql = nullptr;
@@ -39,7 +39,7 @@ static __device__ __forceinline__ void mul_mat_q(

#pragma unroll
for (int ir = 0; ir < qr && ib0 + ir * blocks_per_warp/qr < blocks_per_row_x; ++ir) {
- const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
+ const auto kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;

#pragma unroll
@@ -53,7 +53,7 @@ static __device__ __forceinline__ void mul_mat_q(
#pragma unroll
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
- const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
+ const auto kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

// if the sum is not needed it's faster to transform the scale to f32 ahead of time
@@ -87,14 +87,14 @@ static __device__ __forceinline__ void mul_mat_q(

#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
- const int col_dst = col_dst_0 + j + threadIdx.y;
+ const auto col_dst = col_dst_0 + j + threadIdx.y;
if (col_dst >= ncols_dst) {
return;
}

#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
- const int row_dst = row_dst_0 + threadIdx.x + i;
+ const auto row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
@@ -1,7 +1,7 @@
// copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu
template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) {
- const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const auto row = blockIdx.x*blockDim.y + threadIdx.y;

if (row >= nrows) {
return;
@@ -16,7 +16,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;

- for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
+ for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i; // x block index

const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
@@ -19,10 +19,10 @@ static __device__ __forceinline__ void moe_q(

const int ncols_dst = ncols_y * top_k;

- const int row_dst_0 = blockIdx.x * mmq_y;
+ const auto row_dst_0 = blockIdx.x * mmq_y;
const int& row_x_0 = row_dst_0;

- const int col_dst_0 = blockIdx.y * mmq_x;
+ const auto col_dst_0 = blockIdx.y * mmq_x;

int token_offs[mmq_x / nwarps];
for (int i = 0; i < mmq_x; i += nwarps) {
@@ -56,7 +56,7 @@ static __device__ __forceinline__ void moe_q(
const int n_per_r = ((qk * blocks_per_warp) / qr);
#pragma unroll
for (int ir = 0; ir < qr && ib0 * qk + ir * n_per_r < ncols_x; ++ir) {
- const int kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
+ const auto kqs = ir * WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;

#pragma unroll
@@ -73,7 +73,7 @@ static __device__ __forceinline__ void moe_q(
}

if (threadIdx.x < n_per_r / QK8_1) {
- const int kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
+ const auto kby = threadIdx.x % (WARP_SIZE_GGUF / QI8_1);
const int col_y_eff = token_offs[threadIdx.y] / top_k;
const int block_x =
ib0 * (qk / QK8_1) + ir * (WARP_SIZE_GGUF / QI8_1) + kby;
@@ -119,7 +119,7 @@ static __device__ __forceinline__ void moe_q(

#pragma unroll
for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
- const int row_dst = row_dst_0 + threadIdx.x + i;
+ const auto row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
@@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q4_0 64
- #define MMQ_Y_Q4_0 128
+ #define MOE_X_Q4_0 8
+ #define MOE_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
- #define MMQ_X_Q4_0 4
- #define MMQ_Y_Q4_0 32
+ #define MOE_X_Q4_0 4
+ #define MOE_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif

@@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q4_0;
- const int mmq_y = MMQ_Y_Q4_0;
+ const int mmq_x = MOE_X_Q4_0;
+ const int mmq_y = MOE_Y_Q4_0;
const int nwarps = NWARPS_Q4_0;

moe_q<scalar_t, QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
@@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- int mmq_x = MMQ_X_Q4_0;
- int mmq_y = MMQ_Y_Q4_0;
+ int mmq_x = MOE_X_Q4_0;
+ int mmq_y = MOE_Y_Q4_0;
int nwarps = NWARPS_Q4_0;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q4_1 64
- #define MMQ_Y_Q4_1 128
+ #define MOE_X_Q4_1 8
+ #define MOE_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
- #define MMQ_X_Q4_1 4
- #define MMQ_Y_Q4_1 32
+ #define MOE_X_Q4_1 4
+ #define MOE_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif

@@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q4_1;
- const int mmq_y = MMQ_Y_Q4_1;
+ const int mmq_x = MOE_X_Q4_1;
+ const int mmq_y = MOE_Y_Q4_1;
const int nwarps = NWARPS_Q4_1;

moe_q<scalar_t, QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
@@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- int mmq_x = MMQ_X_Q4_1;
- int mmq_y = MMQ_Y_Q4_1;
+ int mmq_x = MOE_X_Q4_1;
+ int mmq_y = MOE_Y_Q4_1;
int nwarps = NWARPS_Q4_1;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q5_0 64
- #define MMQ_Y_Q5_0 128
+ #define MOE_X_Q5_0 8
+ #define MOE_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
- #define MMQ_X_Q5_0 4
- #define MMQ_Y_Q5_0 32
+ #define MOE_X_Q5_0 4
+ #define MOE_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif

@@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q5_0;
- const int mmq_y = MMQ_Y_Q5_0;
+ const int mmq_x = MOE_X_Q5_0;
+ const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;

moe_q<scalar_t, QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
@@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q5_0;
- const int mmq_y = MMQ_Y_Q5_0;
+ const int mmq_x = MOE_X_Q5_0;
+ const int mmq_y = MOE_Y_Q5_0;
const int nwarps = NWARPS_Q5_0;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q5_1 64
- #define MMQ_Y_Q5_1 128
+ #define MOE_X_Q5_1 8
+ #define MOE_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
- #define MMQ_X_Q5_1 4
- #define MMQ_Y_Q5_1 32
+ #define MOE_X_Q5_1 4
+ #define MOE_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif

@@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q5_1;
- const int mmq_y = MMQ_Y_Q5_1;
+ const int mmq_x = MOE_X_Q5_1;
+ const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;

moe_q<scalar_t, QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
@@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q5_1;
- const int mmq_y = MMQ_Y_Q5_1;
+ const int mmq_x = MOE_X_Q5_1;
+ const int mmq_y = MOE_Y_Q5_1;
const int nwarps = NWARPS_Q5_1;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q8_0 64
- #define MMQ_Y_Q8_0 128
+ #define MOE_X_Q8_0 8
+ #define MOE_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
- #define MMQ_X_Q8_0 4
- #define MMQ_Y_Q8_0 32
+ #define MOE_X_Q8_0 4
+ #define MOE_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif

@@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q8_0;
- const int mmq_y = MMQ_Y_Q8_0;
+ const int mmq_x = MOE_X_Q8_0;
+ const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;

moe_q<scalar_t, QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
@@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q8_0;
- const int mmq_y = MMQ_Y_Q8_0;
+ const int mmq_x = MOE_X_Q8_0;
+ const int mmq_y = MOE_Y_Q8_0;
const int nwarps = NWARPS_Q8_0;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q2_K 64
- #define MMQ_Y_Q2_K 128
+ #define MOE_X_Q2_K 8
+ #define MOE_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
- #define MMQ_X_Q2_K 4
- #define MMQ_Y_Q2_K 32
+ #define MOE_X_Q2_K 4
+ #define MOE_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif

@@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q2_K;
- const int mmq_y = MMQ_Y_Q2_K;
+ const int mmq_x = MOE_X_Q2_K;
+ const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;

moe_q<scalar_t, QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
@@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q2_K;
- const int mmq_y = MMQ_Y_Q2_K;
+ const int mmq_x = MOE_X_Q2_K;
+ const int mmq_y = MOE_Y_Q2_K;
const int nwarps = NWARPS_Q2_K;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q3_K 64
- #define MMQ_Y_Q3_K 128
+ #define MOE_X_Q3_K 8
+ #define MOE_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
- #define MMQ_X_Q3_K 4
- #define MMQ_Y_Q3_K 32
+ #define MOE_X_Q3_K 4
+ #define MOE_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif

@@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {

- const int mmq_x = MMQ_X_Q3_K;
- const int mmq_y = MMQ_Y_Q3_K;
+ const int mmq_x = MOE_X_Q3_K;
+ const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;

moe_q<scalar_t, QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
@@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q3_K;
- const int mmq_y = MMQ_Y_Q3_K;
+ const int mmq_x = MOE_X_Q3_K;
+ const int mmq_y = MOE_Y_Q3_K;
const int nwarps = NWARPS_Q3_K;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q4_K 64
- #define MMQ_Y_Q4_K 128
+ #define MOE_X_Q4_K 8
+ #define MOE_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
- #define MMQ_X_Q4_K 4
- #define MMQ_Y_Q4_K 32
+ #define MOE_X_Q4_K 4
+ #define MOE_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif

@@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q4_K;
- const int mmq_y = MMQ_Y_Q4_K;
+ const int mmq_x = MOE_X_Q4_K;
+ const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;

moe_q<scalar_t, QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
@@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q4_K;
- const int mmq_y = MMQ_Y_Q4_K;
+ const int mmq_x = MOE_X_Q4_K;
+ const int mmq_y = MOE_Y_Q4_K;
const int nwarps = NWARPS_Q4_K;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q5_K 64
- #define MMQ_Y_Q5_K 128
+ #define MOE_X_Q5_K 8
+ #define MOE_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
- #define MMQ_X_Q5_K 4
- #define MMQ_Y_Q5_K 32
+ #define MOE_X_Q5_K 4
+ #define MOE_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif

@@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q5_K;
- const int mmq_y = MMQ_Y_Q5_K;
+ const int mmq_x = MOE_X_Q5_K;
+ const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;

moe_q<scalar_t, QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
@@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q5_K;
- const int mmq_y = MMQ_Y_Q5_K;
+ const int mmq_x = MOE_X_Q5_K;
+ const int mmq_y = MOE_Y_Q5_K;
const int nwarps = NWARPS_Q5_K;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
}

#if defined(USE_ROCM)
- #define MMQ_X_Q6_K 64
- #define MMQ_Y_Q6_K 128
+ #define MOE_X_Q6_K 8
+ #define MOE_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
- #define MMQ_X_Q6_K 4
- #define MMQ_Y_Q6_K 32
+ #define MOE_X_Q6_K 4
+ #define MOE_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif

@@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst,
const int top_k) {
- const int mmq_x = MMQ_X_Q6_K;
- const int mmq_y = MMQ_Y_Q6_K;
+ const int mmq_x = MOE_X_Q6_K;
+ const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;

moe_q<scalar_t, QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
@@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
const int exp_stride, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, const int top_k,
const int tokens_post_padded, cudaStream_t stream) {
- const int mmq_x = MMQ_X_Q6_K;
- const int mmq_y = MMQ_Y_Q6_K;
+ const int mmq_x = MOE_X_Q6_K;
+ const int mmq_y = MOE_Y_Q6_K;
const int nwarps = NWARPS_Q6_K;

const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
@@ -199,12 +199,12 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel(
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int t = threadIdx.x;
+ auto t = threadIdx.x;

// Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+ auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+ auto offset_m = blockIdx.y * m_count;
+ auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -337,12 +337,12 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel(
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int t = threadIdx.x;
+ auto t = threadIdx.x;

// Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+ auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+ auto offset_m = blockIdx.y * m_count;
+ auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -458,12 +458,12 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel(
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int t = threadIdx.x;
+ auto t = threadIdx.x;

// Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+ auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+ auto offset_m = blockIdx.y * m_count;
+ auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -586,12 +586,12 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel(
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int t = threadIdx.x;
+ auto t = threadIdx.x;

// Block
- int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
- int offset_m = blockIdx.y * m_count;
- int offset_k = blockIdx.z * BLOCK_KN_SIZE;
+ auto offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
+ auto offset_m = blockIdx.y * m_count;
+ auto offset_k = blockIdx.z * BLOCK_KN_SIZE;

[[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
[[maybe_unused]] int end_m = min(offset_m + m_count, size_m);
@@ -765,14 +765,14 @@ __global__ void reconstruct_exllama_8bit_kernel(
MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+ auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+ auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
+ auto t = threadIdx.x;

if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -862,14 +862,14 @@ __global__ void reconstruct_exllama_4bit_kernel(
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+ auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+ auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
+ auto t = threadIdx.x;

if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -967,14 +967,14 @@ __global__ void reconstruct_exllama_3bit_kernel(
MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+ auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+ auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
+ auto t = threadIdx.x;

if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1065,14 +1065,14 @@ __global__ void reconstruct_exllama_2bit_kernel(
MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

- int offset_k = BLOCK_KN_SIZE * blockIdx.y;
- int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
+ auto offset_k = BLOCK_KN_SIZE * blockIdx.y;
+ auto offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

// Preload remapping table
__shared__ int perm[BLOCK_KN_SIZE];
- int t = threadIdx.x;
+ auto t = threadIdx.x;

if (b_q_perm) {
if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t];
@@ -1181,11 +1181,11 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
int zero_width = width / 8;
int vec_height = height * 4;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
- int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+ auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
- int h = BLOCK_KN_SIZE * blockIdx.z / 8;
+ auto h = BLOCK_KN_SIZE * blockIdx.z / 8;
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
- int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+ auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
@@ -1197,8 +1197,8 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
}

__shared__ half2 deq2[256][8];
- int val = threadIdx.x / 8;
- int off = threadIdx.x % 8;
+ auto val = threadIdx.x / 8;
+ auto off = threadIdx.x % 8;
for (; val < 256; val += BLOCK_KN_SIZE / 8) {
deq2[val][off] =
__halves2half2(__int2half_rn(val & 0xF), __int2half_rn(val >> 4));
@@ -1280,11 +1280,11 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
int zero_width = width / 4;
int vec_height = height * 2;
const int blockwidth2 = BLOCK_KN_SIZE / 2;
- int b = blockIdx.y * BLOCK_M_SIZE_MAX;
+ auto b = blockIdx.y * BLOCK_M_SIZE_MAX;
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
- int h = BLOCK_KN_SIZE * blockIdx.z / 4;
+ auto h = BLOCK_KN_SIZE * blockIdx.z / 4;
int h_end = min(BLOCK_KN_SIZE / 4, height - h) * 2;
- int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+ auto w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
if (threadIdx.x < h_end) {
@@ -1393,8 +1393,8 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w,
half* __restrict__ out) {
// Start of block

- int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- int row = blockIdx.y * 32 / bit;
+ auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+ auto row = blockIdx.y * 32 / bit;
if (column >= width) return;

// Views
@@ -1425,8 +1425,8 @@ __global__ void reconstruct_gptq_3bit_kernel(
const int height, const int width, const int group,
half* __restrict__ out) {
// Start of block
- int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
- int row = blockIdx.y * 32;
+ auto column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
+ auto row = blockIdx.y * 32;
if (column >= width) return;

// Views
@@ -1542,7 +1542,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,

__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
+ auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1555,7 +1555,7 @@ __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight,

__global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
+ auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1568,7 +1568,7 @@ __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight,

__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
+ auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1581,7 +1581,7 @@ __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight,

__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight,
const int size_k, const int size_n) {
- int n = blockIdx.x * THREADS_X + threadIdx.x;
+ auto n = blockIdx.x * THREADS_X + threadIdx.x;
if (n >= size_n) return;
int k = 0;
uint32_t* b_ptr = b_q_weight + n;
@@ -1599,9 +1599,9 @@ __global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w,
const uint64_t* w2 = (uint64_t*)w;
uint64_t* w_new2 = (uint64_t*)w_new;
int w2_stride = w_width >> 1;
- int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+ auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
- int w_new2_row = blockIdx.y;
+ auto w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 3;
uint64_t dst = 0;

@@ -1630,9 +1630,9 @@ __global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w,
const uint64_t* w2 = (uint64_t*)w;
uint64_t* w_new2 = (uint64_t*)w_new;
int w2_stride = w_width >> 1;
- int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+ auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
- int w_new2_row = blockIdx.y;
+ auto w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 4;
uint64_t dst = 0;

@@ -1658,10 +1658,10 @@ __global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w,
uint32_t* __restrict__ w_new,
const int* __restrict__ q_perm,
const int w_width) {
- int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+ auto w_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w_column >= w_width) return;
- int w_new_row = blockIdx.y * 3;
- int q_perm_idx = blockIdx.y << 5;
+ auto w_new_row = blockIdx.y * 3;
+ auto q_perm_idx = blockIdx.y << 5;
uint32_t dst[3] = {0, 0, 0};

#pragma unroll
@@ -1744,9 +1744,9 @@ __global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w,
const uint64_t* w2 = (uint64_t*)w;
uint64_t* w_new2 = (uint64_t*)w_new;
int w2_stride = w_width >> 1;
- int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
+ auto w2_column = THREADS_X * blockIdx.x + threadIdx.x;
if (w2_column >= w2_stride) return;
- int w_new2_row = blockIdx.y;
+ auto w_new2_row = blockIdx.y;
int q_perm_idx = w_new2_row << 2;
uint64_t dst = 0;

@ -55,11 +55,11 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
|||||||
this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K +
|
this_block_B_base_ptr = params.B_ptr + blockIdx.y * Ntile * params.K +
|
||||||
blockIdx.z * params.SplitK * 4;
|
blockIdx.z * params.SplitK * 4;
|
||||||
|
|
||||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
const auto lane_id = threadIdx.x % WARP_SIZE;
|
||||||
|
|
||||||
// For matrix A, a block load/store Mtile(row) x 32(col) elements in
|
// For matrix A, a block load/store Mtile(row) x 32(col) elements in
|
||||||
// multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter
|
// multiple iters, 8x4 warp load/store 8(row) x 32(col) elements per iter
|
||||||
const int Aldg_row_base_idx = threadIdx.x / 4;
|
const auto Aldg_row_base_idx = threadIdx.x / 4;
|
||||||
Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A;
|
Aldg_col_idx = (threadIdx.x % 4) * LDG_ELEMENT_CNT_A;
|
||||||
const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx;
|
const int Aldg_base_offset = Aldg_row_base_idx * params.K + Aldg_col_idx;
|
||||||
|
|
||||||
@ -67,7 +67,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
|||||||
// elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row)
|
// elements of N32K16 packing in multiple iters, 4x8 warp load/store 4(row)
|
||||||
// * 128(col) per iter
|
// * 128(col) per iter
|
||||||
Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B;
|
Bldg_col_idx = (threadIdx.x % 8) * LDG_ELEMENT_CNT_B;
|
||||||
const int Bldg_row_base_idx = threadIdx.x / 8;
|
const auto Bldg_row_base_idx = threadIdx.x / 8;
|
||||||
const int Bldg_base_offset =
|
const int Bldg_base_offset =
|
||||||
Bldg_row_base_idx * params.K * 4 + Bldg_col_idx;
|
Bldg_row_base_idx * params.K * 4 + Bldg_col_idx;
|
||||||
|
|
||||||
@ -89,7 +89,7 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
|
|||||||
B_ldg_guard = 0;
|
B_ldg_guard = 0;
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
   for (int i = 0; i < (Mtile + M_SIZE_ONE_LOAD - 1) / M_SIZE_ONE_LOAD; ++i) {
-    int m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
+    auto m_idx = blockIdx.x * Mtile + Aldg_row_base_idx + i * M_SIZE_ONE_LOAD;
     if (m_idx < params.M) {
       A_ldg_guard |= (1u << i);
     }
@@ -98,8 +98,8 @@ struct GmemTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
   const int N_padded = (params.N + 31) / 32 * 32;
 #pragma unroll
   for (int i = 0; i < (Ntile + N_SIZE_ONE_LOAD - 1) / N_SIZE_ONE_LOAD; ++i) {
-    int n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
+    auto n_idx = blockIdx.y * Ntile + (Bldg_row_base_idx / 8) * 32 +
                 i * N_SIZE_ONE_LOAD;
     if (n_idx < N_padded) {
       B_ldg_guard |= (1u << i);
     }
@@ -355,7 +355,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
   __device__ void fused_splitk_reduce() {
     // need splitk-reduce if enable splitk
     if (gridDim.z > 1) {
-      int blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
+      auto blk_red_idx = blockIdx.x * gridDim.y + blockIdx.y;
       // Wait for all previous blocks in the splitk direction to accumulate the
       // results into C_tmp
       if (threadIdx.x == 0) {
@@ -371,7 +371,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
       }
       __syncthreads();

-      int C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
+      auto C_tmp_base_offset = blk_red_idx * Mtile * Ntile + threadIdx.x * 4;
       if (blockIdx.z != 0) {
        // expecting that temporary register here reuses the previous A&B frag
        // register
@@ -456,7 +456,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {

     FType* C_base_ptr = this_block_C_base_ptr + store_c_base_offset;
     // C_tile lds and stg
-    int m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
+    auto m_base_idx = store_c_row_base_idx + blockIdx.x * Mtile;
     bool n_guard = (store_c_col_idx + blockIdx.y * Ntile) < params.N;
     if (WARP_NTILE == 32) {
       int lds_c_base_offset = warp_id * Mtile * WARP_NTILE +
@@ -580,9 +580,9 @@ __global__ void __launch_bounds__(BLOCK)
   int sts_stage_idx = 0;
   int lds_stage_idx = 0;

-  int tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
+  auto tb_k_slice = blockIdx.z * params.SplitK + params.SplitK <= params.K
                        ? params.SplitK
                        : params.K - blockIdx.z * params.SplitK;
   int k_tiles = (tb_k_slice + 31) / 32;
   int first_k_tile = tb_k_slice - (k_tiles - 1) * 32;

@@ -777,13 +777,13 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
     const QT* qdata, const FT* scales, const FT* zeros, FT* fdata,
     const int N_32align, const int N, const int K) {
   __shared__ FT smem[64 * 32];
-  int warp_id = threadIdx.x / 32;
-  int lane_id = threadIdx.x % 32;
-  const int src_row_idx = blockIdx.x * 8 + lane_id / 4;
+  auto warp_id = threadIdx.x / 32;
+  auto lane_id = threadIdx.x % 32;
+  const auto src_row_idx = blockIdx.x * 8 + lane_id / 4;
   const int src_col_idx =
       blockIdx.y * 64 * 4 + warp_id * 16 * 4 + (lane_id % 4) * 16;
   const int src_offset = src_row_idx * K * 4 + src_col_idx;
-  int params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;
+  auto params_nidx = blockIdx.x * 32 + (lane_id / 4) * 4;

   QT qval_reg[16];
   const QT* pdata = qdata + src_offset;
@@ -829,8 +829,8 @@ __global__ void restore_N32_K16_dequantize_rhs_w8a16_perc_kernel(
         *reinterpret_cast<uint4*>(smem + lds_base_offset + i * 32 * 32);
   }

-  const int dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
-  const int dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
+  const auto dst_row_base_kidx = blockIdx.y * 64 + threadIdx.x / 4;
+  const auto dst_col_nidx = blockIdx.x * 32 + (threadIdx.x % 4) * 8;
 #pragma unroll
   for (int i = 0; i < 2; ++i) {
     int dst_row_kidx = dst_row_base_kidx + i * 32;
@@ -1008,4 +1008,4 @@ torch::Tensor allspark_w8a16_gemm(

 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
   m.impl("allspark_w8a16_gemm", &allspark_w8a16_gemm);
 }
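A change that repeats throughout these kernels is declaring indices derived from blockIdx/threadIdx with auto instead of int. The CUDA built-in thread and block indices are unsigned int, so the index arithmetic is unsigned and auto keeps it that way instead of narrowing back to a signed int, which is presumably what the compiler warnings being addressed here pointed at. A minimal sketch of the deduction, using a hypothetical kernel that is not part of this diff:

#include <cuda_runtime.h>

// Hypothetical kernel illustrating the int -> auto change: the arithmetic on
// blockIdx.x/blockDim.x/threadIdx.x is unsigned, so `auto` deduces
// `unsigned int` and no implicit unsigned->signed narrowing takes place.
__global__ void index_demo_kernel(const float* in, float* out, unsigned int n) {
  auto idx = blockIdx.x * blockDim.x + threadIdx.x;  // deduced as unsigned int
  if (idx < n) {
    out[idx] = in[idx] * 2.0f;  // arbitrary per-element work
  }
}

int main() {
  // Launch-shape sketch only; error handling omitted for brevity.
  const unsigned int n = 1024;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  index_demo_kernel<<<(n + 255) / 256, 256>>>(in, out, n);
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}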
@@ -13,8 +13,8 @@ __global__ void __launch_bounds__(128)
     const uint8_t* B, const FType* B_scale, const FType* B_zero,
     uint8_t* B_result, FType* B_scale_result, FType* B_zero_result,
     const int K, const int N, const int N_32align) {
-  const int lane_id = threadIdx.x % 32;
-  const int warp_id = threadIdx.x / 32;
+  const auto lane_id = threadIdx.x % 32;
+  const auto warp_id = threadIdx.x / 32;

   if (blockIdx.x != gridDim.x - 1) {
     // Load B
@@ -50,7 +50,7 @@ __global__ void __launch_bounds__(128)
     }

     // Store B
-    const int dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
+    const auto dst_row_base_idx = blockIdx.y * (128 / 4) + (lane_id / 8) * 8;
     const int dst_col_idx =
         blockIdx.x * (64 * 4) + warp_id * 64 + (lane_id % 8) * 8;
     for (int i = 0; i < 8; ++i) {
@@ -65,7 +65,7 @@ __global__ void __launch_bounds__(128)
   } else {
     // Load B_scale and B_zero
     FType b_scale_reg, b_zero_reg;
-    int src_offset = blockIdx.y * 128 + threadIdx.x;
+    auto src_offset = blockIdx.y * 128 + threadIdx.x;
     ldg16_cg_0(b_scale_reg, B_scale + src_offset, src_offset < N);
     if (B_zero != nullptr)
       ldg16_cg_0(b_zero_reg, B_zero + src_offset, src_offset < N);
@@ -62,7 +62,7 @@ template <typename FType, int BLOCK, int N_MATRIX>
 __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
                                               uint32_t n, uint32_t n_matrix,
                                               uint32_t matrix_size) {
-  int idx = blockIdx.x * BLOCK + threadIdx.x;
+  auto idx = blockIdx.x * BLOCK + threadIdx.x;

   if (idx >= matrix_size) {
     return;
@@ -407,4 +407,4 @@ static __device__ half2 inline num2num2(const half x) {
   return __half2half2(x);
 }

 }  // namespace allspark
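For context, f16_gemm_splitk_reduce_kernel implements the usual split-K epilogue: each K-split writes its partial C into its own slice of a temporary buffer, and a small elementwise kernel then sums the slices into the final output. A simplified, self-contained sketch of that pattern (names and layout are illustrative, not the kernel in this diff):

#include <cuda_fp16.h>

// Simplified split-K reduction: C_split holds `n_matrix` partial results laid
// out back to back, each `matrix_size` half elements long; every thread sums
// one element across all partials and writes the final value to C.
__global__ void splitk_reduce_sketch(const half* C_split, half* C,
                                     unsigned int n_matrix,
                                     unsigned int matrix_size) {
  auto idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= matrix_size) {
    return;
  }
  float acc = 0.0f;  // accumulate in fp32 for accuracy
  for (unsigned int m = 0; m < n_matrix; ++m) {
    acc += __half2float(C_split[m * matrix_size + idx]);
  }
  C[idx] = __float2half(acc);
}

A launch of roughly <<<(matrix_size + 255) / 256, 256>>> covers one thread per output element.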
@@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);

-  int start_k_tile = blockIdx.x * block_k_tiles;
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
     int4* sh_ptr = sh + stage_size * pipe;

     if (threadIdx.x < stage_size) {
-      int k_id = threadIdx.x / stage_n_threads;
-      int n_id = threadIdx.x % stage_n_threads;
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;

       int first_k = k_tile_id * tile_k_size;

@@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
       return;
     }

-    int warp_id = threadIdx.x / 32;
-    int th_id = threadIdx.x % 32;
+    auto warp_id = threadIdx.x / 32;
+    auto th_id = threadIdx.x % 32;

     if (warp_id >= 4) {
       return;
@@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {

 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
   m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
 }
@@ -42,7 +42,7 @@ namespace marlin {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {}
+                                    int size_k, int lda, int block_rows) {}

 template <typename scalar_t,  // compute dtype, half or nv_float16
           const vllm::ScalarTypeId w_type_id,  // weight ScalarType id
@@ -459,29 +459,32 @@ __device__ inline void barrier_release(int* lock, bool reset = false) {
 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
-                                    int size_k, int block_rows) {
-  int start_row = block_rows * blockIdx.x;
+                                    int size_k, int lda, int block_rows) {
+  auto start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
     finish_row = size_m;
   }
   int cur_block_rows = finish_row - start_row;

-  int row_stride = size_k * sizeof(half) / 16;
+  int input_row_stride = lda * sizeof(half) / 16;
+  int output_row_stride = size_k * sizeof(half) / 16;

   auto permute_row = [&](int row) {
     int iters = size_k / default_threads;
     int rest = size_k % default_threads;

-    int offset = row * row_stride;
+    int input_offset = row * input_row_stride;
+    int output_offset = row * output_row_stride;

-    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
-    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);
+    half const* a_row_half =
+        reinterpret_cast<half const*>(a_int4_ptr + input_offset);
+    half* out_half = reinterpret_cast<half*>(out_int4_ptr + output_offset);

     int base_k = 0;

     for (int i = 0; i < iters; i++) {
-      int cur_k = base_k + threadIdx.x;
+      auto cur_k = base_k + threadIdx.x;
       int src_pos = perm_int_ptr[cur_k];

       out_half[cur_k] = a_row_half[src_pos];
@@ -491,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,

     if (rest) {
       if (threadIdx.x < rest) {
-        int cur_k = base_k + threadIdx.x;
+        auto cur_k = base_k + threadIdx.x;
         int src_pos = perm_int_ptr[cur_k];

         out_half[cur_k] = a_row_half[src_pos];
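The lda parameter threaded through permute_cols_kernel and Marlin separates the leading dimension of A in global memory from the logical K dimension, so A no longer has to be contiguous (for example, a row-major view with padding between rows). In the Marlin kernel the same separation shows up as `int a_gl_stride = lda / 8;`, since A is read through 16-byte int4 vectors, i.e. 8 fp16 values per load. The addressing idea, reduced to plain pointer arithmetic in a hypothetical helper kernel (not the kernel in this diff):

#include <cuda_fp16.h>

// Hypothetical row-copy kernel showing lda-style addressing: input row `row`
// starts at in + row * lda (lda >= k, counted in elements), while the packed
// output row starts at out + row * k. With lda == k this degenerates to the
// contiguous case.
__global__ void copy_rows_with_lda(const half* in, half* out, int m, int k,
                                   int lda) {
  int row = static_cast<int>(blockIdx.x);  // one block per row
  if (row >= m) return;
  const half* src = in + static_cast<size_t>(row) * lda;
  half* dst = out + static_cast<size_t>(row) * k;
  for (int col = static_cast<int>(threadIdx.x); col < k;
       col += static_cast<int>(blockDim.x)) {
    dst[col] = src[col];
  }
}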
@@ -537,6 +540,7 @@ __global__ void Marlin(
     int prob_m,           // batch dimension m
     int prob_n,           // output dimension n
     int prob_k,           // reduction dimension k
+    int lda,              // A.stride(0), equal to prob_k is A is contiguous
     int* locks,           // extra global storage for barrier synchronization
     bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
@@ -600,7 +604,7 @@ __global__ void Marlin(
   // We can easily implement parallel problem execution by just remapping
   // indices and advancing global pointers
   if (slice_col_par >= n_tiles) {
-    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
+    A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8;
     C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
     locks += (slice_col_par / n_tiles) * n_tiles;
     slice_col = slice_col_par % n_tiles;
@@ -631,7 +635,7 @@ __global__ void Marlin(
       }
     }
     if (slice_col == n_tiles) {
-      A += 16 * thread_m_blocks * prob_k / 8;
+      A += 16 * thread_m_blocks * lda / 8;
       C += 16 * thread_m_blocks * prob_n / 8;
       locks += n_tiles;
       slice_col = 0;
@@ -643,7 +647,7 @@ __global__ void Marlin(
   // A sizes/strides

   // stride of the A matrix in global memory
-  int a_gl_stride = prob_k / 8;
+  int a_gl_stride = lda / 8;
   // stride of an A matrix tile in shared memory
   constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
   // delta between subsequent A tiles in global memory
@@ -719,8 +723,8 @@ __global__ void Marlin(
               (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x * b_thread_vecs;
-  int b_sh_rd = threadIdx.x * b_thread_vecs;
+  auto b_sh_wr = threadIdx.x * b_thread_vecs;
+  auto b_sh_rd = threadIdx.x * b_thread_vecs;

   // For act_order
   constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -739,7 +743,7 @@ __global__ void Marlin(
                 s_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

   // Zero-points
@@ -752,7 +756,7 @@ __global__ void Marlin(
                  zp_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int zp_sh_wr = threadIdx.x;
+  auto zp_sh_wr = threadIdx.x;
   bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;

   // We use a different scale layout for grouped and column-wise quantization as
@@ -1043,7 +1047,7 @@ __global__ void Marlin(
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;
@@ -1081,7 +1085,7 @@ __global__ void Marlin(

       // Determine "position" inside the thread-block (based on warp and
       // thread-id)
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps =
           thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N

@@ -1090,7 +1094,7 @@ __global__ void Marlin(

       cur_k += warp_row * 16;

-      int th_id = threadIdx.x % 32;
+      auto th_id = threadIdx.x % 32;
       cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix

       int s_col_shift =
@@ -1155,7 +1159,7 @@ __global__ void Marlin(
               (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
         }
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;
@@ -1193,7 +1197,7 @@ __global__ void Marlin(
                        (pipe / (group_blocks / thread_k_blocks)));
         reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;
@@ -1319,7 +1323,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride_threads;
+      auto red_idx = threadIdx.x / b_sh_stride_threads;
       constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride_threads;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -1386,7 +1390,7 @@ __global__ void Marlin(
                   4 * (threadIdx.x / 32) + threadIdx.x % 4;
     c_gl_wr += (2 * thread_n_blocks) * slice_col;
     constexpr int c_sh_wr_delta = active_threads;
-    int c_sh_wr = threadIdx.x;
+    auto c_sh_wr = threadIdx.x;

     int row = (threadIdx.x % 32) / 4;

@@ -1780,8 +1784,8 @@ __global__ void Marlin(
             HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                         \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(             \
             A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,  \
-            num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \
-            use_fp32_reduce);                                          \
+            num_groups, prob_m, prob_n, prob_k, lda, locks,            \
+            part_use_atomic_add, use_fp32_reduce);                     \
   }                                                                    \
   }
@@ -2071,7 +2075,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
 template <typename scalar_t>
 void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
-               int prob_n, int prob_k, void* workspace,
+               int prob_n, int prob_k, int lda, void* workspace,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
@@ -2184,8 +2188,9 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
     // Permute A columns
     int block_rows = div_ceil(prob_m, blocks);
     permute_cols_kernel<<<blocks, default_threads, 0, stream>>>(
-        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows);
+        A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows);
     A_ptr = a_tmp_ptr;
+    lda = prob_k;
   }

   // If we have a full K, then we can run the non-act-order version of Marlin
@@ -2210,6 +2215,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
       thread_m_blocks = exec_cfg.max_m_blocks;
     }

+    // atomic add reduce have better performance only when m * n is small
+    bool part_use_atomic_add =
+        use_atomic_add && div_ceil(prob_m, 64) * prob_n <= 2048;
+
     if (false) {
     }
     GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
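The new part_use_atomic_add flag gates the atomic-add reduction path on the output size, following the comment added above: atomics into C pay off only when ceil(prob_m / 64) * prob_n is small, otherwise the existing fp32 temporary-buffer reduction is kept. A standalone sketch of the heuristic with a couple of worked values (the helper below is illustrative, not vLLM code):

#include <cstdio>

// Illustrative re-statement of the heuristic added in this hunk.
static inline int div_ceil_sketch(int a, int b) { return (a + b - 1) / b; }

static bool should_use_atomic_add(bool use_atomic_add, int prob_m, int prob_n) {
  return use_atomic_add && div_ceil_sketch(prob_m, 64) * prob_n <= 2048;
}

int main() {
  // ceil(16/64) * 4096 = 4096 > 2048 -> keep the fp32-buffer reduction
  printf("%d\n", should_use_atomic_add(true, 16, 4096));
  // ceil(16/64) * 2048 = 2048 <= 2048 -> atomic add path
  printf("%d\n", should_use_atomic_add(true, 16, 2048));
  return 0;
}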
@@ -2244,7 +2253,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                  ", num_bits = ", num_bits);
     }

-    A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
+    A_ptr += 16 * thread_m_blocks * (lda / 8) * par;
     C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
   }
 }
@@ -2300,7 +2309,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,

   // Verify device and strides
   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
-  TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+  TORCH_CHECK(a.stride(1) == 1, "A.stride(1) is not 1");
+  // We use int4 (16 bytes) to load A, so A must aligned to 16 bytes
+  TORCH_CHECK(a.stride(0) % 8 == 0, "A.stride(0) must divisible by 8");
+  TORCH_CHECK(((uint64_t)a.data_ptr()) % 16 == 0, "A must aligned to 16 bytes");

   TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
   TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
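These relaxed checks replace the blanket is_contiguous() requirement with the minimum the kernel actually needs: unit stride along K, a row stride (A.stride(0)) that is a multiple of 8 elements, and a 16-byte-aligned base pointer, because A is read through 16-byte int4 vectors (8 fp16 values each), so every row start must itself land on a 16-byte boundary. A small sketch of that alignment reasoning, assuming a hypothetical helper:

#include <cassert>
#include <cstdint>

// Why lda % 8 == 0 and a 16-byte base are still required: rows of A are read
// as int4 vectors (16 bytes == 8 fp16 values), so base + row * lda *
// sizeof(half) must be 16-byte aligned for every row. Hypothetical helper,
// not vLLM API.
static bool a_layout_ok(const void* base, int64_t lda_elems) {
  bool base_aligned = (reinterpret_cast<uintptr_t>(base) % 16) == 0;
  bool rows_aligned = (lda_elems % 8) == 0;  // 8 halves == 16 bytes per int4
  return base_aligned && rows_aligned;
}

int main() {
  alignas(16) static uint16_t buf[64 * 4096];  // stand-in for an fp16 tensor
  assert(a_layout_ok(buf, 4096));   // contiguous rows of 4096 halves
  assert(!a_layout_ok(buf, 4100));  // padded lda that is not a multiple of 8
  return 0;
}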
@@ -2432,7 +2444,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
         c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
         b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
-        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
+        a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k, a.stride(0),
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
         thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
@@ -2443,10 +2455,10 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
         b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
-        workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
-        num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
-        use_fp32_reduce, is_zp_float);
+        a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full,
+        has_zp, num_groups, group_size, dev,
+        at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
+        marlin::max_par, use_atomic_add, use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }
@@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
   int n_tiles = size_n / tile_n_size;
   int block_k_tiles = div_ceil(k_tiles, gridDim.x);

-  int start_k_tile = blockIdx.x * block_k_tiles;
+  auto start_k_tile = blockIdx.x * block_k_tiles;
   if (start_k_tile >= k_tiles) {
     return;
   }
@@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(

   if constexpr (has_perm) {
     if (threadIdx.x < stage_size) {
-      int k_id = threadIdx.x / stage_n_threads;
-      int n_id = threadIdx.x % stage_n_threads;
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;

       uint32_t const* sh_perm_int_ptr =
           reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(

   } else {
     if (threadIdx.x < stage_size) {
-      int k_id = threadIdx.x / stage_n_threads;
-      int n_id = threadIdx.x % stage_n_threads;
+      auto k_id = threadIdx.x / stage_n_threads;
+      auto n_id = threadIdx.x % stage_n_threads;

       int first_k = k_tile_id * tile_k_size;
       int first_k_packed = first_k / pack_factor;
@@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
       return;
     }

-    int warp_id = threadIdx.x / 32;
-    int th_id = threadIdx.x % 32;
+    auto warp_id = threadIdx.x / 32;
+    auto th_id = threadIdx.x % 32;

     if (warp_id >= 4) {
       return;
@@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {

 TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
   m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
 }
Some files were not shown because too many files have changed in this diff.