mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-25 18:17:48 +08:00 
			
		
		
		
	Compare commits
	
		
			712 Commits
		
	
	
		
			v0.9.1rc2
			...
			codex/chan
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 5bb81e284b | |||
| 01513a334a | |||
| ac2bf41e53 | |||
| a931b4cdcf | |||
| a0f8a79646 | |||
| 18bdcf4113 | |||
| 1c3198b6c4 | |||
| 260127ea54 | |||
| d0dc4cfca4 | |||
| d31a647124 | |||
| 85431bd9ad | |||
| c11013db8b | |||
| 1eb2b9c102 | |||
| 6ebf313790 | |||
| cfbcb9ed87 | |||
| 76ddeff293 | |||
| f46098335b | |||
| e9534c7202 | |||
| 7976446015 | |||
| fcb9f879c1 | |||
| 3ed94f9d0a | |||
| fa839565f2 | |||
| 75a99b98bf | |||
| b5c3b68359 | |||
| 6cbc4d4bea | |||
| 153c6f1e61 | |||
| 34cda778a0 | |||
| 30800b01c2 | |||
| 10be209493 | |||
| 19c863068b | |||
| f29fd8a7f8 | |||
| ed10f3cea1 | |||
| b637e9dcb8 | |||
| 1e36c8687e | |||
| 5bac61362b | |||
| 313ae8c16a | |||
| c847e34b39 | |||
| e7e3e6d263 | |||
| 4ffd963fa0 | |||
| 56fe4bedd6 | |||
| d91278181d | |||
| 20149d84d9 | |||
| 3534c39a20 | |||
| c586b55667 | |||
| 33d560001e | |||
| f148c44c6a | |||
| 235bfd5dfe | |||
| 68d28e37b0 | |||
| 37a7d5d74a | |||
| d4d309409f | |||
| 85bd6599e4 | |||
| 91b3d190ae | |||
| fc017915f5 | |||
| 9ad0a4588b | |||
| 016b8d1b7f | |||
| 80305c1b24 | |||
| 37e2ecace2 | |||
| 054c8657e3 | |||
| d4170fad39 | |||
| 946aadb4a0 | |||
| bcdfb2a330 | |||
| ba8c300018 | |||
| 8cdc371217 | |||
| 61e20828da | |||
| 55e1c66da5 | |||
| 86f3ac21ce | |||
| 149f2435a5 | |||
| c0569dbc82 | |||
| 8bb43b9c9e | |||
| 559756214b | |||
| 6d0cf239c6 | |||
| 3fc964433a | |||
| 0caf61c08a | |||
| 667624659b | |||
| 38efa28278 | |||
| e8cc53af5e | |||
| a4851cfe68 | |||
| 9887e8ec50 | |||
| f326ab9c88 | |||
| dcf2a5e208 | |||
| 1e9438e0b0 | |||
| 697ef765ee | |||
| a99b9f7dee | |||
| c488b928a7 | |||
| 2c7fa47161 | |||
| 88fc8a97e3 | |||
| 66f6fbd393 | |||
| 8632e831ba | |||
| 4bbfc36b16 | |||
| 80d38b8ac8 | |||
| 211b6a6113 | |||
| 247102f07f | |||
| bd4c1e6fdb | |||
| 99b4f080d8 | |||
| 020f58abcd | |||
| c1acd6d7d4 | |||
| 3b3b778d4a | |||
| 42d440c22b | |||
| f45a332886 | |||
| 6e2c176e1f | |||
| a86754a12b | |||
| c2a2f19aba | |||
| 2c11a738b3 | |||
| b639327ad9 | |||
| 4afe687a82 | |||
| 5de8d9f111 | |||
| c1c8ca57ff | |||
| a3a5a47e48 | |||
| fb25e95688 | |||
| 0d4891cd03 | |||
| f56d2996ca | |||
| 147afb448b | |||
| 3c7d942da8 | |||
| 890323dc1b | |||
| 01cae37713 | |||
| 11c0198615 | |||
| b1235c3e10 | |||
| 44d02f54db | |||
| a8593237c0 | |||
| fc0f41d10a | |||
| 7b828e30d5 | |||
| 5f0af36af5 | |||
| 0d21b2664c | |||
| 9907fc4494 | |||
| d47661f0cd | |||
| 53fa457391 | |||
| 6fb162447b | |||
| 66177189c5 | |||
| b4f0b5f9aa | |||
| cbd14ed561 | |||
| 7bd4c37ae7 | |||
| 8020e98c9f | |||
| 762be26a8e | |||
| 6a9e6b2abf | |||
| 5d09152ff1 | |||
| 31d5c1797f | |||
| 35514b682a | |||
| e2de455c34 | |||
| 5b032352cc | |||
| 922f316441 | |||
| 5923ab9524 | |||
| 0cf893cae1 | |||
| cf75cd2098 | |||
| b854321ffe | |||
| 5b6fe23d05 | |||
| f0c98cae27 | |||
| 574ad60db9 | |||
| fdadb6f43a | |||
| 41060c6e08 | |||
| 3de2ed767f | |||
| 299252ea82 | |||
| d6902ce79f | |||
| 5e53c89a74 | |||
| c66e38ea4c | |||
| 251595368f | |||
| 4bed167768 | |||
| b140416abf | |||
| 5b8366b61a | |||
| c7753a9809 | |||
| 4b9a9435bb | |||
| 3482fd7e4e | |||
| 77f77a951e | |||
| 1a4f35e2ea | |||
| be1e128dfb | |||
| 65393ee064 | |||
| dc221ad72d | |||
| 7571a4a7e5 | |||
| f67d986dd1 | |||
| cc876d0f29 | |||
| fdfd409f8f | |||
| ffbcc9e757 | |||
| 59389c927b | |||
| 8f2720def9 | |||
| ad6c2e1a0b | |||
| 49e8c7ea25 | |||
| 805d62ca88 | |||
| b7d9e9416f | |||
| 7c12a765aa | |||
| cd587c93ef | |||
| 332d4cb17b | |||
| bf03ff3575 | |||
| 47043eb678 | |||
| 31b96d1c64 | |||
| e59ba9e142 | |||
| 403b481573 | |||
| 138709f8d1 | |||
| 0bbac1c1b4 | |||
| a3e4e85ece | |||
| eb58f5953d | |||
| 4ac9c33f78 | |||
| efe73d0575 | |||
| 853487bc1b | |||
| 9ff2af6d2b | |||
| 70ca5484f5 | |||
| 5358cce5ff | |||
| 2155e95ef1 | |||
| f95570a52d | |||
| b6e7e3d58f | |||
| e760fcef22 | |||
| 6bbf1795b7 | |||
| 9e0ef888f0 | |||
| 97abeb1daa | |||
| 34dad19e7b | |||
| 6db31e7a27 | |||
| 977180c912 | |||
| c40784c794 | |||
| baed180aa0 | |||
| 0b407479ef | |||
| 5eaf570050 | |||
| d8ee5a2ca4 | |||
| b9fca83256 | |||
| 32dffc2772 | |||
| c438183e99 | |||
| baba0389f7 | |||
| c6c22f16d3 | |||
| dd382e0fe3 | |||
| 849590a2a7 | |||
| a4c23314c0 | |||
| b942c094e3 | |||
| b4bab81660 | |||
| b91cb3fa5c | |||
| 71d1d75b7a | |||
| 72d14d0eed | |||
| e34d130c16 | |||
| 7721ef1786 | |||
| 8369b7c2a9 | |||
| 3eb4ad53f3 | |||
| 90a2769f20 | |||
| e60d422f19 | |||
| 0d914c81a2 | |||
| 6e428cdd7a | |||
| 93b9d9f499 | |||
| af107d5a0e | |||
| 31c5d0a1b7 | |||
| afb7cff1b9 | |||
| d2e841a10a | |||
| 14601f5fba | |||
| 042d131f39 | |||
| 8e807cdfa4 | |||
| e601efcb10 | |||
| 22dd9c2730 | |||
| a6d795d593 | |||
| a37d75bbec | |||
| edd270bc78 | |||
| 110df74332 | |||
| 1ad69e8375 | |||
| b8a498c9b2 | |||
| 923147b5e8 | |||
| 45877ef740 | |||
| 6e4bef1bea | |||
| 4ff79a136e | |||
| 448acad31e | |||
| eb0b2d2f08 | |||
| 3112271f6e | |||
| 1fd471e957 | |||
| 2c5ebec064 | |||
| 2e610deb72 | |||
| 6e2c19ce22 | |||
| 47db8c2c15 | |||
| 462b269280 | |||
| c18b3b8e8b | |||
| 9528e3a05e | |||
| 9fb52e523a | |||
| e202dd2736 | |||
| 43813e6361 | |||
| cede942b87 | |||
| fe1e924811 | |||
| 4548c03c50 | |||
| 40b86aa05e | |||
| 432870829d | |||
| f73d02aadc | |||
| c5ebe040ac | |||
| 8d763cb891 | |||
| cf4cd53982 | |||
| 32c9be2200 | |||
| 8aeaa910a2 | |||
| 906e05d840 | |||
| ef9a2990ae | |||
| 7e90870491 | |||
| d3f05c9248 | |||
| c108781c85 | |||
| 3d184b95b8 | |||
| 2f35a022e6 | |||
| ffe00ef77a | |||
| 5561681d04 | |||
| fbd62d8750 | |||
| 2e26f9156a | |||
| 9e5452ee34 | |||
| 0e3fe896e2 | |||
| 1caca5a589 | |||
| 783921d889 | |||
| 4a98edff1f | |||
| a7bab0c9e5 | |||
| 25950dca9b | |||
| a4113b035c | |||
| 7e1665b089 | |||
| 8d1096e7db | |||
| 8d775dd30a | |||
| 78fe77534b | |||
| 2f2fcb31b8 | |||
| 1dba2c4ebe | |||
| 71d6de3a26 | |||
| 536fd33003 | |||
| 619b9f5c7e | |||
| d1b689c445 | |||
| 9854dc9040 | |||
| ff5c60fad8 | |||
| 6f1229f91d | |||
| 1819fbda63 | |||
| 7f0367109e | |||
| fb14d53cf6 | |||
| b024a42e93 | |||
| cb97f2bfc5 | |||
| 359200f6ac | |||
| 220aee902a | |||
| 67d25eca05 | |||
| 363528de27 | |||
| 4ff61ababa | |||
| 0ec3779df7 | |||
| b616f6a53d | |||
| 2e25bb12a8 | |||
| 9965c47d0d | |||
| 059d4cdb49 | |||
| bdb84e26b0 | |||
| 3dd359147d | |||
| 657f2f301a | |||
| a1aafc827a | |||
| 139508a418 | |||
| d265414dbc | |||
| 48fb076cbc | |||
| c1909e7e8c | |||
| b95877509b | |||
| 706ff13224 | |||
| ccbfb1d1c9 | |||
| 9e5552aa13 | |||
| 0c600b9ab6 | |||
| e303dcf523 | |||
| ae9c4d416f | |||
| d853520b3e | |||
| ba51aea65e | |||
| 8452946c06 | |||
| 2e7cbf2d7d | |||
| 7da296be04 | |||
| b205e8467d | |||
| be0cfb2b68 | |||
| 1a03dd496b | |||
| 27b8017636 | |||
| 9ec1e3065a | |||
| 9dae7d46bf | |||
| 7058d7dd5d | |||
| a0389e0554 | |||
| 3be8d312a2 | |||
| 3abfe22154 | |||
| e81fbefe8a | |||
| 9290de5667 | |||
| 7f280d69c9 | |||
| 02cabff207 | |||
| 3d19d47d91 | |||
| 8acb4badee | |||
| 314af8617c | |||
| 0e96cc9b7e | |||
| ecad851cbd | |||
| ed70f3c64f | |||
| 650d5dbd04 | |||
| 9025a9a705 | |||
| c05596f1a3 | |||
| 787b13389e | |||
| 96453cfa83 | |||
| b1c1fe35a5 | |||
| 08d81f1014 | |||
| 6cc1e7d96d | |||
| 9909726d2a | |||
| 22e9d42040 | |||
| 86debab54c | |||
| be250bbc67 | |||
| 27949354fa | |||
| bd5038af07 | |||
| a2f14dc8f9 | |||
| 92ee7baaf9 | |||
| 7151f92241 | |||
| e28533a16f | |||
| 6d42ce8315 | |||
| ded1fb635b | |||
| 97d9524fe9 | |||
| d8cf819a9a | |||
| 551ef1631a | |||
| 2863befce3 | |||
| 2965c99c86 | |||
| 2062c0723d | |||
| 1c50e100a9 | |||
| 3ee56e26be | |||
| 8fe7fc8634 | |||
| e936e401de | |||
| f5dfa07531 | |||
| 022c58b80f | |||
| 19108ef311 | |||
| 5a52f389dd | |||
| 65b1cbb138 | |||
| 6c9837a761 | |||
| 6f2f53a82d | |||
| 7b1895e6ce | |||
| 4d36693687 | |||
| daec9dea6e | |||
| daceac57c7 | |||
| 8615d9776f | |||
| 7b460c25f9 | |||
| f719772281 | |||
| d45417b804 | |||
| a29e62ea34 | |||
| e53be6f00a | |||
| c329ceca6d | |||
| 3c545c0c3b | |||
| e8c3bd2cd1 | |||
| c6c983053d | |||
| aafabaa0d5 | |||
| 94a55c7681 | |||
| aa0dc77ef5 | |||
| 4ab3ac285e | |||
| d1c956dc0f | |||
| dec197e3e5 | |||
| 6e244ae091 | |||
| cd4cfee689 | |||
| e110930680 | |||
| 8b64c895c0 | |||
| 0740e29b66 | |||
| 44d2e6af63 | |||
| 2d7779f888 | |||
| a57d57fa72 | |||
| 71799fd005 | |||
| e9fd658a73 | |||
| 07b8fae219 | |||
| 562308816c | |||
| 04e1642e32 | |||
| b69781f107 | |||
| 0bceac9810 | |||
| 34878a0b48 | |||
| 6393b03986 | |||
| 0907d507bf | |||
| c894c5dc1f | |||
| 1f5d178e9c | |||
| 27c065df50 | |||
| 84c260caeb | |||
| 167aca45cb | |||
| 0567c8249f | |||
| d188913d99 | |||
| 1d7c29f5fe | |||
| 65397e40f5 | |||
| 9502c38138 | |||
| 2582683566 | |||
| 754b00edb3 | |||
| 296ce95d8e | |||
| 2d7620c3eb | |||
| 55c65ab495 | |||
| 2cc2069970 | |||
| 9f0608fc16 | |||
| 4e0db57fff | |||
| c40692bf9a | |||
| 4734704b30 | |||
| 8b8c209e35 | |||
| 23a04e0895 | |||
| 02c97d9a92 | |||
| e795d723ed | |||
| 8359f4c8d8 | |||
| bf5181583f | |||
| c53fec1fcb | |||
| 0f9e7354f5 | |||
| ba7ba35cda | |||
| 015fab8c2f | |||
| f59fc60fb3 | |||
| 879f69bed3 | |||
| 7108934142 | |||
| 3443aaf8dd | |||
| 2273ec322c | |||
| a6c4b87fbc | |||
| 1afa9948f5 | |||
| 0d06b533a0 | |||
| c01d1c5aba | |||
| ead369845d | |||
| c6e3bba8e6 | |||
| 91f7d9d0b6 | |||
| 8619e7158c | |||
| c635c5f744 | |||
| a045b7e89a | |||
| 981eeca41a | |||
| 26d34eb67e | |||
| 53da4cd397 | |||
| 9a3b88328f | |||
| 3014c920da | |||
| 0eed516951 | |||
| ee5ad8d2c5 | |||
| a738dbb2a1 | |||
| 33d5e29be9 | |||
| 4671ac6e2a | |||
| dd2ccf8dde | |||
| a3bc76e4b5 | |||
| e6327c9b3e | |||
| d0132f025d | |||
| 61f4fc5dc6 | |||
| 68aaeb3749 | |||
| c3649e4fee | |||
| 53243e5c42 | |||
| a6e6604d32 | |||
| b82e0f82cb | |||
| 5111642a6f | |||
| 1bcd15edc7 | |||
| 2ebff5b77c | |||
| f17aec0d63 | |||
| 493c275352 | |||
| f39ab2d4bd | |||
| 4a0f7888a3 | |||
| c4cf260677 | |||
| 33d51f599e | |||
| e91386cde1 | |||
| 2c11a29f0b | |||
| c76a506bd6 | |||
| ec0db6f51c | |||
| c305a2109d | |||
| 202c5df935 | |||
| 2bb246b8f7 | |||
| 4c409cabc2 | |||
| 3b1e4c6a23 | |||
| 2c5302fadd | |||
| caa680fd2e | |||
| c3bf9bad11 | |||
| 6f170f11dd | |||
| 8ca81bb069 | |||
| e773a9e1c2 | |||
| 71baf85ae1 | |||
| 79f2f1c2a1 | |||
| 2e3e3c86dc | |||
| 7e8977fcd4 | |||
| f1e840e842 | |||
| 7771d1de88 | |||
| 71d1219545 | |||
| e384f2f108 | |||
| 089a306f19 | |||
| 5e666f72cd | |||
| e3a3e4db46 | |||
| e41bf15cd0 | |||
| 5aa4a015ce | |||
| b6bad3d186 | |||
| ee9a1531aa | |||
| 10d82f9ac5 | |||
| ea10dd9d9e | |||
| ead2110297 | |||
| 01220ce89a | |||
| 6f68c49220 | |||
| 4719460644 | |||
| 466166dcfd | |||
| 1d0ae26c85 | |||
| 6021999573 | |||
| c7b370c603 | |||
| aa20d10a91 | |||
| 2de12be428 | |||
| 83ca9ae47b | |||
| e2148dc5ea | |||
| b1098b4072 | |||
| 799397ee4f | |||
| 4959915089 | |||
| 8d1e89d946 | |||
| 36239f79dd | |||
| dfada85eee | |||
| ed33349738 | |||
| d49adea1f9 | |||
| 14fdd21d39 | |||
| 04fefe7c9a | |||
| 3b523e38d9 | |||
| 16c16301c8 | |||
| 9206d0ff01 | |||
| a89209b78d | |||
| ffacb222cb | |||
| 12575cfa7a | |||
| 8b6e1d639c | |||
| 735a9de71f | |||
| 257ab95439 | |||
| cca91a7a10 | |||
| f04d604567 | |||
| 19a53b2783 | |||
| eccdc8318c | |||
| 5f52a84685 | |||
| d4629dc43f | |||
| 6e9cc73f67 | |||
| c53711bd63 | |||
| dac8cc49f4 | |||
| a44b1c951d | |||
| b447624ee3 | |||
| cda92307c1 | |||
| bf57ccc5c2 | |||
| ffb2cd6b54 | |||
| ca94d7fa00 | |||
| 5a1c2e15d8 | |||
| 4c8f64faa7 | |||
| 93aee29fdb | |||
| 154d063b9f | |||
| ccd7c05089 | |||
| c48c6c4008 | |||
| aed8468642 | |||
| 5c76b9cdaf | |||
| ddfed314f9 | |||
| 5b3ad5ecf2 | |||
| ede5c4ebdf | |||
| 07334959d8 | |||
| 119f683949 | |||
| 0860087aff | |||
| 6bc7b57315 | |||
| 90f9c2eb5c | |||
| 387bdf0ab9 | |||
| 5e5baa91aa | |||
| 836d4ce140 | |||
| c3fec47bb7 | |||
| 1173804dca | |||
| 4d5424029b | |||
| 3e7506975c | |||
| ee35e96ac3 | |||
| dec66d253b | |||
| 8d120701fd | |||
| f40f763f12 | |||
| 26bc46ef89 | |||
| a77aea59fd | |||
| b692e9cd07 | |||
| 367871a469 | |||
| 92183b41f3 | |||
| c6703d1e0d | |||
| a5e7242d5f | |||
| 91b2c17a55 | |||
| 055915e6ce | |||
| 3d330c4c09 | |||
| 0b73736a0d | |||
| ee1531bc38 | |||
| e13945f9dd | |||
| 08500011d3 | |||
| 861a0a0a39 | |||
| bc956b38d0 | |||
| 294fc1e2c9 | |||
| 2db9044ab6 | |||
| 6fa718a460 | |||
| 06be858828 | |||
| d1e34cc9ac | |||
| bd517eb9fe | |||
| d65668b4e8 | |||
| aafbbd981f | |||
| 0f0874515a | |||
| 3597b06a4f | |||
| 1015296b79 | |||
| ce9dc02c93 | |||
| a24cb91600 | |||
| 7e8d97dd3f | |||
| d70bc7c029 | |||
| ce688ad46e | |||
| cefdb9962d | |||
| ace5cdaff0 | |||
| 6458721108 | |||
| bb4a0decef | |||
| c707cfc12e | |||
| 7b3c9ff91d | |||
| c68698b326 | |||
| e3b12667d4 | |||
| e6aab5de29 | |||
| c57bb199b3 | |||
| dba68f9159 | |||
| a3319f4f04 | |||
| 9d880f594d | |||
| 017ef648e9 | |||
| 4b25ab14e2 | |||
| f98548b9da | |||
| 96846bb360 | |||
| b6efafd9e4 | |||
| 1129e2b1ab | |||
| c742438f8b | |||
| 73e2e0118f | |||
| c9280e6346 | |||
| af09b3f0a0 | |||
| 4f6c42fa0a | |||
| dff680001d | |||
| 2e090bd5df | |||
| 1b0b065eb5 | |||
| d5bdf899e4 | |||
| 7e3e74c97c | |||
| 3f6341bf7f | |||
| e5d35d62f5 | |||
| 2f1c19b245 | |||
| 42f52cc95b | |||
| 97a9465bbc | |||
| c7ea0b56cd | |||
| 29fa5cac1c | |||
| b2d9be6f7d | |||
| 04a55612dd | |||
| 89b0f84e17 | |||
| 497a91e9f7 | |||
| 943ffa5703 | |||
| 5c8d34a42c | |||
| 3c8694eabe | |||
| 7484e1fce2 | |||
| a2142f0196 | |||
| 871d6b7c74 | |||
| 29a38f0352 | |||
| a5115f4ff5 | |||
| 68b4a26149 | |||
| b8e809a057 | |||
| 5039ec2336 | |||
| 7c644ab6d5 | |||
| 2d40665fe8 | |||
| 96ada386b7 | |||
| 1e473b3010 | |||
| 2b1e2111b0 | |||
| a45b979d9f | |||
| 3952731e8f | |||
| 77f0d465d0 | |||
| 22c3c0aa4a | |||
| 33f8dba7c6 | |||
| 5241ca50d6 | |||
| da9b523ce1 | 
| @ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do | |||||||
| done | done | ||||||
|  |  | ||||||
| lm_eval --model vllm \ | lm_eval --model vllm \ | ||||||
|   --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \ |   --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \ | ||||||
|   --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ |   --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ | ||||||
|   --batch_size "$BATCH_SIZE" |   --batch_size "$BATCH_SIZE" | ||||||
|  | |||||||
| @ -18,12 +18,14 @@ RTOL = 0.08 | |||||||
|  |  | ||||||
| def launch_lm_eval(eval_config, tp_size): | def launch_lm_eval(eval_config, tp_size): | ||||||
|     trust_remote_code = eval_config.get("trust_remote_code", False) |     trust_remote_code = eval_config.get("trust_remote_code", False) | ||||||
|  |     max_model_len = eval_config.get("max_model_len", 4096) | ||||||
|     model_args = ( |     model_args = ( | ||||||
|         f"pretrained={eval_config['model_name']}," |         f"pretrained={eval_config['model_name']}," | ||||||
|         f"tensor_parallel_size={tp_size}," |         f"tensor_parallel_size={tp_size}," | ||||||
|         f"enforce_eager=true," |         f"enforce_eager=true," | ||||||
|         f"add_bos_token=true," |         f"add_bos_token=true," | ||||||
|         f"trust_remote_code={trust_remote_code}" |         f"trust_remote_code={trust_remote_code}," | ||||||
|  |         f"max_model_len={max_model_len}" | ||||||
|     ) |     ) | ||||||
|     results = lm_eval.simple_evaluate( |     results = lm_eval.simple_evaluate( | ||||||
|         model="vllm", |         model="vllm", | ||||||
|  | |||||||
| @ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc | |||||||
|  |  | ||||||
| ## Performance benchmark quick overview | ## Performance benchmark quick overview | ||||||
|  |  | ||||||
| **Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models. | **Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models. | ||||||
|  |  | ||||||
| **Benchmarking Duration**: about 1hr. | **Benchmarking Duration**: about 1hr. | ||||||
|  |  | ||||||
| @ -31,13 +31,27 @@ Performance benchmark will be triggered when: | |||||||
| - A PR being merged into vllm. | - A PR being merged into vllm. | ||||||
| - Every commit for those PRs with `perf-benchmarks` label AND `ready` label. | - Every commit for those PRs with `perf-benchmarks` label AND `ready` label. | ||||||
|  |  | ||||||
|  | Manually Trigger the benchmark | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Runtime environment variables: | ||||||
|  | - `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0. | ||||||
|  | - `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file). | ||||||
|  | - `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file). | ||||||
|  | - `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use default file). | ||||||
|  | - `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string. | ||||||
|  | - `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string. | ||||||
|  |  | ||||||
| Nightly benchmark will be triggered when: | Nightly benchmark will be triggered when: | ||||||
| - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. | - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. | ||||||
|  |  | ||||||
| ## Performance benchmark details | ## Performance benchmark details | ||||||
|  |  | ||||||
| See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. | See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. | ||||||
|  | > NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead. | ||||||
| ### Latency test | ### Latency test | ||||||
|  |  | ||||||
| Here is an example of one test inside `latency-tests.json`: | Here is an example of one test inside `latency-tests.json`: | ||||||
| @ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running. | |||||||
| The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. | The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. | ||||||
| The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. | The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. | ||||||
|  |  | ||||||
|  | The `compare-json-results.py` script compares benchmark result JSON files produced by `convert-results-json-to-markdown.py`. | ||||||
|  | When run, the benchmark script generates results under the `benchmark/results` folder, along with `benchmark_results.md` and `benchmark_results.json`. | ||||||
|  | `compare-json-results.py` compares two `benchmark_results.json` files and reports the performance ratio, e.g. for Output Tput, Median TTFT and Median TPOT. | ||||||
|  |  | ||||||
|  | Here is an example using the script to compare results_a and results_b without detailed test names. | ||||||
|  | `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name` | ||||||
|  |  | ||||||
|  | |    | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio        | | ||||||
|  | |----|----------------------------------------|----------------------------------------|----------| | ||||||
|  | | 0  | 142.633982                             | 156.526018                             | 1.097396 | | ||||||
|  | | 1  | 241.620334                             | 294.018783                             | 1.216863 | | ||||||
|  | | 2  | 218.298905                             | 262.664916                             | 1.203235 | | ||||||
|  | | 3  | 242.743860                             | 299.816190                             | 1.235113 | | ||||||
|  |  | ||||||
|  | Here is an example using the script to compare results_a and results_b with detailed test names. | ||||||
|  | `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` | ||||||
|  | |   | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio        | | ||||||
|  | |---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------| | ||||||
|  | | 0 | serving_llama8B_tp1_sharegpt_qps_1          | 142.633982                             | serving_llama8B_tp1_sharegpt_qps_1          | 156.526018                             | 1.097396 | | ||||||
|  | | 1 | serving_llama8B_tp1_sharegpt_qps_16         | 241.620334                             | serving_llama8B_tp1_sharegpt_qps_16         | 294.018783                             | 1.216863 | | ||||||
|  | | 2 | serving_llama8B_tp1_sharegpt_qps_4          | 218.298905                             | serving_llama8B_tp1_sharegpt_qps_4          | 262.664916                             | 1.203235 | | ||||||
|  | | 3 | serving_llama8B_tp1_sharegpt_qps_inf        | 242.743860                             | serving_llama8B_tp1_sharegpt_qps_inf        | 299.816190                             | 1.235113 | | ||||||
|  | | 4 | serving_llama8B_tp2_random_1024_128_qps_1   | 96.613390                              | serving_llama8B_tp4_random_1024_128_qps_1   | 108.404853                             | 1.122048 | | ||||||
|  |  | ||||||
| ## Nightly test details | ## Nightly test details | ||||||
|  |  | ||||||
| See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. | See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. | ||||||
|  | |||||||
| @ -16,7 +16,7 @@ Please download the visualization scripts in the post | |||||||
|   - Download `nightly-benchmarks.zip`. |   - Download `nightly-benchmarks.zip`. | ||||||
|   - In the same folder, run the following code: |   - In the same folder, run the following code: | ||||||
|  |  | ||||||
|   ```console |   ```bash | ||||||
|   export HF_TOKEN=<your HF token> |   export HF_TOKEN=<your HF token> | ||||||
|   apt update |   apt update | ||||||
|   apt install -y git |   apt install -y git | ||||||
|  | |||||||
| @ -4,7 +4,8 @@ | |||||||
| - Input length: 32 tokens. | - Input length: 32 tokens. | ||||||
| - Output length: 128 tokens. | - Output length: 128 tokens. | ||||||
| - Batch size: fixed (8). | - Batch size: fixed (8). | ||||||
| - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | - GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||||
|  | - CPU Models: llama-3.1 8B. | ||||||
| - Evaluation metrics: end-to-end latency (mean, median, p99). | - Evaluation metrics: end-to-end latency (mean, median, p99). | ||||||
|  |  | ||||||
| {latency_tests_markdown_table} | {latency_tests_markdown_table} | ||||||
| @ -14,7 +15,8 @@ | |||||||
| - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). | - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). | ||||||
| - Output length: the corresponding output length of these 200 prompts. | - Output length: the corresponding output length of these 200 prompts. | ||||||
| - Batch size: dynamically determined by vllm to achieve maximum throughput. | - Batch size: dynamically determined by vllm to achieve maximum throughput. | ||||||
| - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | - GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||||
|  | - CPU Models: llama-3.1 8B. | ||||||
| - Evaluation metrics: throughput. | - Evaluation metrics: throughput. | ||||||
|  |  | ||||||
| {throughput_tests_markdown_table} | {throughput_tests_markdown_table} | ||||||
| @ -25,12 +27,18 @@ | |||||||
| - Output length: the corresponding output length of these 200 prompts. | - Output length: the corresponding output length of these 200 prompts. | ||||||
| - Batch size: dynamically determined by vllm and the arrival pattern of the requests. | - Batch size: dynamically determined by vllm and the arrival pattern of the requests. | ||||||
| - **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). | - **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). | ||||||
| - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | - GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||||
| - We also added a speculative decoding test for llama-3 70B, under QPS 2 | - We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2 | ||||||
|  | - CPU Models: llama-3.1 8B. | ||||||
| - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). | - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). | ||||||
|  | - For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts. | ||||||
|  |  | ||||||
| {serving_tests_markdown_table} | {serving_tests_markdown_table} | ||||||
|  |  | ||||||
|  | ## Platform Information | ||||||
|  |  | ||||||
|  | {platform_markdown_table} | ||||||
|  |  | ||||||
| ## json version of the benchmarking tables | ## json version of the benchmarking tables | ||||||
|  |  | ||||||
| This section contains the data of the markdown tables above in JSON format. | This section contains the data of the markdown tables above in JSON format. | ||||||
|  | |||||||
| @ -0,0 +1,66 @@ | |||||||
|  | # SPDX-License-Identifier: Apache-2.0 | ||||||
|  | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|  | import argparse | ||||||
|  |  | ||||||
|  | import pandas as pd | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def compare_data_columns( | ||||||
|  |     files, name_column, data_column, drop_column, ignore_test_name=False | ||||||
|  | ): | ||||||
|  |     print("\ncompare_data_column: " + data_column) | ||||||
|  |     frames = [] | ||||||
|  |     compare_frames = [] | ||||||
|  |     for file in files: | ||||||
|  |         data_df = pd.read_json(file) | ||||||
|  |         serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) | ||||||
|  |         if ignore_test_name is False: | ||||||
|  |             serving_df = serving_df.rename(columns={name_column: file + "_name"}) | ||||||
|  |             frames.append(serving_df[file + "_name"]) | ||||||
|  |         serving_df = serving_df.rename(columns={data_column: file}) | ||||||
|  |         frames.append(serving_df[file]) | ||||||
|  |         compare_frames.append(serving_df[file]) | ||||||
|  |         if len(compare_frames) >= 2: | ||||||
|  |             # Compare numbers among two files | ||||||
|  |             ratio_df = compare_frames[1] / compare_frames[0] | ||||||
|  |             frames.append(ratio_df) | ||||||
|  |             compare_frames.pop(1) | ||||||
|  |  | ||||||
|  |     concat_df = pd.concat(frames, axis=1) | ||||||
|  |     return concat_df | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument( | ||||||
|  |         "-f", "--file", action="append", type=str, help="input file name" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--ignore_test_name", action="store_true", help="ignore_test_name or not" | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  |     files = args.file | ||||||
|  |     print("comparing : " + ", ".join(files)) | ||||||
|  |  | ||||||
|  |     drop_column = "P99" | ||||||
|  |     name_column = "Test name" | ||||||
|  |     data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"] | ||||||
|  |     html_msgs_for_data_cols = [ | ||||||
|  |         "Compare Output Tokens /n", | ||||||
|  |         "Median TTFT /n", | ||||||
|  |         "Median TPOT /n", | ||||||
|  |     ] | ||||||
|  |     ignore_test_name = args.ignore_test_name | ||||||
|  |     with open("perf_comparison.html", "w") as text_file: | ||||||
|  |         for i in range(len(data_cols_to_compare)): | ||||||
|  |             output_df = compare_data_columns( | ||||||
|  |                 files, | ||||||
|  |                 name_column, | ||||||
|  |                 data_cols_to_compare[i], | ||||||
|  |                 drop_column, | ||||||
|  |                 ignore_test_name=ignore_test_name, | ||||||
|  |             ) | ||||||
|  |             print(output_df) | ||||||
|  |             html = output_df.to_html() | ||||||
|  |             text_file.write(html_msgs_for_data_cols[i]) | ||||||
|  |             text_file.write(html) | ||||||
| @ -3,9 +3,11 @@ | |||||||
|  |  | ||||||
| import json | import json | ||||||
| import os | import os | ||||||
|  | from importlib import util | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
|  |  | ||||||
| import pandas as pd | import pandas as pd | ||||||
|  | import psutil | ||||||
| from tabulate import tabulate | from tabulate import tabulate | ||||||
|  |  | ||||||
| results_folder = Path("results/") | results_folder = Path("results/") | ||||||
| @ -29,11 +31,11 @@ throughput_results = [] | |||||||
| throughput_results_column_mapping = { | throughput_results_column_mapping = { | ||||||
|     "test_name": "Test name", |     "test_name": "Test name", | ||||||
|     "gpu_type": "GPU", |     "gpu_type": "GPU", | ||||||
|     # "num_requests": "# of req.", |     "num_requests": "# of req.", | ||||||
|     # "total_num_tokens": "Total # of tokens", |     "total_num_tokens": "Total # of tokens", | ||||||
|     # "elapsed_time": "Elapsed time (s)", |     "elapsed_time": "Elapsed time (s)", | ||||||
|     "requests_per_second": "Tput (req/s)", |     "requests_per_second": "Tput (req/s)", | ||||||
|     # "tokens_per_second": "Tput (tok/s)", |     "tokens_per_second": "Tput (tok/s)", | ||||||
| } | } | ||||||
|  |  | ||||||
| # serving results and the keys that will be printed into markdown | # serving results and the keys that will be printed into markdown | ||||||
| @ -41,16 +43,18 @@ serving_results = [] | |||||||
| serving_column_mapping = { | serving_column_mapping = { | ||||||
|     "test_name": "Test name", |     "test_name": "Test name", | ||||||
|     "gpu_type": "GPU", |     "gpu_type": "GPU", | ||||||
|     # "completed": "# of req.", |     "completed": "# of req.", | ||||||
|     "request_throughput": "Tput (req/s)", |     "request_throughput": "Tput (req/s)", | ||||||
|     # "input_throughput": "Input Tput (tok/s)", |     "total_token_throughput": "Total Token Tput (tok/s)", | ||||||
|     # "output_throughput": "Output Tput (tok/s)", |     "output_throughput": "Output Tput (tok/s)", | ||||||
|  |     "total_input_tokens": "Total input tokens", | ||||||
|  |     "total_output_tokens": "Total output tokens", | ||||||
|     "mean_ttft_ms": "Mean TTFT (ms)", |     "mean_ttft_ms": "Mean TTFT (ms)", | ||||||
|     "median_ttft_ms": "Median TTFT (ms)", |     "median_ttft_ms": "Median TTFT (ms)", | ||||||
|     "p99_ttft_ms": "P99 TTFT (ms)", |     "p99_ttft_ms": "P99 TTFT (ms)", | ||||||
|     # "mean_tpot_ms": "Mean TPOT (ms)", |     "mean_tpot_ms": "Mean TPOT (ms)", | ||||||
|     # "median_tpot_ms": "Median", |     "median_tpot_ms": "Median", | ||||||
|     # "p99_tpot_ms": "P99", |     "p99_tpot_ms": "P99", | ||||||
|     "mean_itl_ms": "Mean ITL (ms)", |     "mean_itl_ms": "Mean ITL (ms)", | ||||||
|     "median_itl_ms": "Median ITL (ms)", |     "median_itl_ms": "Median ITL (ms)", | ||||||
|     "p99_itl_ms": "P99 ITL (ms)", |     "p99_itl_ms": "P99 ITL (ms)", | ||||||
| @ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_size_with_unit(bytes, suffix="B"): | ||||||
|  |     """ | ||||||
|  |     Scale bytes to its proper format | ||||||
|  |     e.g: | ||||||
|  |         1253656 => '1.20MB' | ||||||
|  |         1253656678 => '1.17GB' | ||||||
|  |     """ | ||||||
|  |     factor = 1024 | ||||||
|  |     for unit in ["", "K", "M", "G", "T", "P"]: | ||||||
|  |         if bytes < factor: | ||||||
|  |             return f"{bytes:.2f}{unit}{suffix}" | ||||||
|  |         bytes /= factor | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # collect results |     # collect results | ||||||
|     for test_file in results_folder.glob("*.json"): |     for test_file in results_folder.glob("*.json"): | ||||||
| @ -155,6 +173,27 @@ if __name__ == "__main__": | |||||||
|     serving_results = pd.DataFrame.from_dict(serving_results) |     serving_results = pd.DataFrame.from_dict(serving_results) | ||||||
|     throughput_results = pd.DataFrame.from_dict(throughput_results) |     throughput_results = pd.DataFrame.from_dict(throughput_results) | ||||||
|  |  | ||||||
|  |     svmem = psutil.virtual_memory() | ||||||
|  |     platform_data = { | ||||||
|  |         "Physical cores": [psutil.cpu_count(logical=False)], | ||||||
|  |         "Total cores": [psutil.cpu_count(logical=True)], | ||||||
|  |         "Total Memory": [get_size_with_unit(svmem.total)], | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if util.find_spec("numa") is not None: | ||||||
|  |         from numa import info | ||||||
|  |  | ||||||
|  |         platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()] | ||||||
|  |  | ||||||
|  |     if util.find_spec("cpuinfo") is not None: | ||||||
|  |         from cpuinfo import get_cpu_info | ||||||
|  |  | ||||||
|  |         platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]] | ||||||
|  |  | ||||||
|  |     platform_results = pd.DataFrame.from_dict( | ||||||
|  |         platform_data, orient="index", columns=["Platform Info"] | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     raw_results_json = results_to_json( |     raw_results_json = results_to_json( | ||||||
|         latency_results, throughput_results, serving_results |         latency_results, throughput_results, serving_results | ||||||
|     ) |     ) | ||||||
| @ -200,6 +239,9 @@ if __name__ == "__main__": | |||||||
|     throughput_md_table = tabulate( |     throughput_md_table = tabulate( | ||||||
|         throughput_results, headers="keys", tablefmt="pipe", showindex=False |         throughput_results, headers="keys", tablefmt="pipe", showindex=False | ||||||
|     ) |     ) | ||||||
|  |     platform_md_table = tabulate( | ||||||
|  |         platform_results, headers="keys", tablefmt="pipe", showindex=True | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     # document the result |     # document the result | ||||||
|     with open(results_folder / "benchmark_results.md", "w") as f: |     with open(results_folder / "benchmark_results.md", "w") as f: | ||||||
| @ -211,6 +253,7 @@ if __name__ == "__main__": | |||||||
|             latency_tests_markdown_table=latency_md_table, |             latency_tests_markdown_table=latency_md_table, | ||||||
|             throughput_tests_markdown_table=throughput_md_table, |             throughput_tests_markdown_table=throughput_md_table, | ||||||
|             serving_tests_markdown_table=serving_md_table, |             serving_tests_markdown_table=serving_md_table, | ||||||
|  |             platform_markdown_table=platform_md_table, | ||||||
|             benchmarking_results_in_json_string=processed_results_json, |             benchmarking_results_in_json_string=processed_results_json, | ||||||
|         ) |         ) | ||||||
|         f.write(results) |         f.write(results) | ||||||
|  | |||||||
| @ -31,6 +31,20 @@ check_gpus() { | |||||||
|   echo "GPU type is $gpu_type" |   echo "GPU type is $gpu_type" | ||||||
| } | } | ||||||
|  |  | ||||||
|  | check_cpus() { | ||||||
|  |   # check the number of CPUs and NUMA Node and GPU type. | ||||||
|  |   declare -g numa_count=$(python3 -c  "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)") | ||||||
|  |   if [[ $numa_count -gt 0 ]]; then | ||||||
|  |     echo "NUMA found." | ||||||
|  |     echo $numa_count | ||||||
|  |   else | ||||||
|  |     echo "Need at least 1 NUMA to run benchmarking." | ||||||
|  |     exit 1 | ||||||
|  |   fi | ||||||
|  |   declare -g gpu_type="cpu" | ||||||
|  |   echo "GPU type is $gpu_type" | ||||||
|  | } | ||||||
|  |  | ||||||
| check_hf_token() { | check_hf_token() { | ||||||
|   # check if HF_TOKEN is available and valid |   # check if HF_TOKEN is available and valid | ||||||
|   if [[ -z "$HF_TOKEN" ]]; then |   if [[ -z "$HF_TOKEN" ]]; then | ||||||
| @ -69,6 +83,22 @@ json2args() { | |||||||
|   echo "$args" |   echo "$args" | ||||||
| } | } | ||||||
|  |  | ||||||
|  | json2envs() { | ||||||
|  |   # transforms the JSON string to environment variables. | ||||||
|  |   # example: | ||||||
|  |   # input: { "VLLM_CPU_KVCACHE_SPACE": 5 } | ||||||
|  |   # output: VLLM_CPU_KVCACHE_SPACE=5 | ||||||
|  |   local json_string=$1 | ||||||
|  |   local args=$( | ||||||
|  |     echo "$json_string" | jq -r ' | ||||||
|  |       to_entries | | ||||||
|  |       map((.key ) + "=" + (.value | tostring)) | | ||||||
|  |       join(" ") | ||||||
|  |     ' | ||||||
|  |   ) | ||||||
|  |   echo "$args" | ||||||
|  | } | ||||||
|  |  | ||||||
| wait_for_server() { | wait_for_server() { | ||||||
|   # wait for vllm server to start |   # wait for vllm server to start | ||||||
|   # return 1 if vllm server crashes |   # return 1 if vllm server crashes | ||||||
| @ -158,15 +188,24 @@ run_latency_tests() { | |||||||
|     # get arguments |     # get arguments | ||||||
|     latency_params=$(echo "$params" | jq -r '.parameters') |     latency_params=$(echo "$params" | jq -r '.parameters') | ||||||
|     latency_args=$(json2args "$latency_params") |     latency_args=$(json2args "$latency_params") | ||||||
|  |     latency_environment_variables=$(echo "$params" | jq -r '.environment_variables') | ||||||
|  |     latency_envs=$(json2envs "$latency_environment_variables") | ||||||
|  |  | ||||||
|     # check if there is enough GPU to run the test |     # check if there is enough GPU to run the test | ||||||
|     tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') |     tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') | ||||||
|     if [[ $gpu_count -lt $tp ]]; then |     if [ "$ON_CPU" == "1" ];then | ||||||
|       echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." |       if [[ $numa_count -lt $tp ]]; then | ||||||
|       continue |         echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." | ||||||
|  |         continue | ||||||
|  |       fi | ||||||
|  |     else | ||||||
|  |       if [[ $gpu_count -lt $tp ]]; then | ||||||
|  |         echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||||
|  |         continue | ||||||
|  |       fi | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
|     latency_command="python3 benchmark_latency.py \ |     latency_command=" $latency_envs python3 benchmark_latency.py \ | ||||||
|       --output-json $RESULTS_FOLDER/${test_name}.json \ |       --output-json $RESULTS_FOLDER/${test_name}.json \ | ||||||
|       $latency_args" |       $latency_args" | ||||||
|  |  | ||||||
| @ -216,15 +255,24 @@ run_throughput_tests() { | |||||||
|     # get arguments |     # get arguments | ||||||
|     throughput_params=$(echo "$params" | jq -r '.parameters') |     throughput_params=$(echo "$params" | jq -r '.parameters') | ||||||
|     throughput_args=$(json2args "$throughput_params") |     throughput_args=$(json2args "$throughput_params") | ||||||
|  |     throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables') | ||||||
|  |     throughput_envs=$(json2envs "$throughput_environment_variables") | ||||||
|  |  | ||||||
|     # check if there is enough GPU to run the test |     # check if there is enough GPU to run the test | ||||||
|     tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') |     tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') | ||||||
|     if [[ $gpu_count -lt $tp ]]; then |     if [ "$ON_CPU" == "1" ];then | ||||||
|       echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." |       if [[ $numa_count -lt $tp ]]; then | ||||||
|       continue |         echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." | ||||||
|  |         continue | ||||||
|  |       fi | ||||||
|  |     else | ||||||
|  |       if [[ $gpu_count -lt $tp ]]; then | ||||||
|  |         echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||||
|  |         continue | ||||||
|  |       fi | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
|     throughput_command="python3 benchmark_throughput.py \ |     throughput_command=" $throughput_envs python3 benchmark_throughput.py \ | ||||||
|       --output-json $RESULTS_FOLDER/${test_name}.json \ |       --output-json $RESULTS_FOLDER/${test_name}.json \ | ||||||
|       $throughput_args" |       $throughput_args" | ||||||
|  |  | ||||||
| @ -272,18 +320,27 @@ run_serving_tests() { | |||||||
|  |  | ||||||
|     # get client and server arguments |     # get client and server arguments | ||||||
|     server_params=$(echo "$params" | jq -r '.server_parameters') |     server_params=$(echo "$params" | jq -r '.server_parameters') | ||||||
|  |     server_envs=$(echo "$params" | jq -r '.server_environment_variables') | ||||||
|     client_params=$(echo "$params" | jq -r '.client_parameters') |     client_params=$(echo "$params" | jq -r '.client_parameters') | ||||||
|     server_args=$(json2args "$server_params") |     server_args=$(json2args "$server_params") | ||||||
|  |     server_envs=$(json2envs "$server_envs") | ||||||
|     client_args=$(json2args "$client_params") |     client_args=$(json2args "$client_params") | ||||||
|     qps_list=$(echo "$params" | jq -r '.qps_list') |     qps_list=$(echo "$params" | jq -r '.qps_list') | ||||||
|     qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') |     qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') | ||||||
|     echo "Running over qps list $qps_list" |     echo "Running over qps list $qps_list" | ||||||
|  |  | ||||||
|     # check if there is enough GPU to run the test |     # check if there is enough resources to run the test | ||||||
|     tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') |     tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') | ||||||
|     if [[ $gpu_count -lt $tp ]]; then |     if [ "$ON_CPU" == "1" ];then | ||||||
|       echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." |       if [[ $numa_count -lt $tp ]]; then | ||||||
|       continue |         echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." | ||||||
|  |         continue | ||||||
|  |       fi | ||||||
|  |     else | ||||||
|  |       if [[ $gpu_count -lt $tp ]]; then | ||||||
|  |         echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||||
|  |         continue | ||||||
|  |       fi | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
|     # check if server model and client model is aligned |     # check if server model and client model is aligned | ||||||
| @ -294,23 +351,33 @@ run_serving_tests() { | |||||||
|       continue |       continue | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
|     server_command="python3 \ |     server_command="$server_envs python3 \ | ||||||
|       -m vllm.entrypoints.openai.api_server \ |       -m vllm.entrypoints.openai.api_server \ | ||||||
|       $server_args" |       $server_args" | ||||||
|  |  | ||||||
|     # run the server |     # run the server | ||||||
|     echo "Running test case $test_name" |     echo "Running test case $test_name" | ||||||
|     echo "Server command: $server_command" |     echo "Server command: $server_command" | ||||||
|     bash -c "$server_command" & |     # support remote vllm server | ||||||
|     server_pid=$! |     client_remote_args="" | ||||||
|  |     if [[ -z "${REMOTE_HOST}" ]]; then | ||||||
|     # wait until the server is alive |       bash -c "$server_command" & | ||||||
|     if wait_for_server; then |       server_pid=$! | ||||||
|       echo "" |       # wait until the server is alive | ||||||
|       echo "vllm server is up and running." |       if wait_for_server; then | ||||||
|  |         echo "" | ||||||
|  |         echo "vLLM server is up and running." | ||||||
|  |       else | ||||||
|  |         echo "" | ||||||
|  |         echo "vLLM failed to start within the timeout period." | ||||||
|  |       fi | ||||||
|     else |     else | ||||||
|       echo "" |       server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT" | ||||||
|       echo "vllm failed to start within the timeout period." |       if [[ ${REMOTE_PORT} ]]; then | ||||||
|  |         client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT " | ||||||
|  |       else | ||||||
|  |         client_remote_args=" --host=$REMOTE_HOST " | ||||||
|  |       fi | ||||||
|     fi |     fi | ||||||
|  |  | ||||||
|     # iterate over different QPS |     # iterate over different QPS | ||||||
| @ -332,7 +399,7 @@ run_serving_tests() { | |||||||
|         --result-filename ${new_test_name}.json \ |         --result-filename ${new_test_name}.json \ | ||||||
|         --request-rate $qps \ |         --request-rate $qps \ | ||||||
|         --metadata "tensor_parallel_size=$tp" \ |         --metadata "tensor_parallel_size=$tp" \ | ||||||
|         $client_args" |         $client_args $client_remote_args " | ||||||
|  |  | ||||||
|       echo "Running test case $test_name with qps $qps" |       echo "Running test case $test_name with qps $qps" | ||||||
|       echo "Client command: $client_command" |       echo "Client command: $client_command" | ||||||
| @ -360,7 +427,14 @@ run_serving_tests() { | |||||||
| } | } | ||||||
|  |  | ||||||
| main() { | main() { | ||||||
|   check_gpus |   local ARCH | ||||||
|  |   ARCH='' | ||||||
|  |   if [ "$ON_CPU" == "1" ];then | ||||||
|  |      check_cpus | ||||||
|  |      ARCH='-cpu' | ||||||
|  |   else | ||||||
|  |      check_gpus | ||||||
|  |   fi | ||||||
|   check_hf_token |   check_hf_token | ||||||
|  |  | ||||||
|   # Set to v1 to run v1 benchmark |   # Set to v1 to run v1 benchmark | ||||||
| @ -386,9 +460,9 @@ main() { | |||||||
|   QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ |   QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ | ||||||
|  |  | ||||||
|   # benchmarking |   # benchmarking | ||||||
|   run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json |   run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}" | ||||||
|   run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json |   run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}" | ||||||
|   run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json |   run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}" | ||||||
|  |  | ||||||
|   # postprocess benchmarking results |   # postprocess benchmarking results | ||||||
|   pip install tabulate pandas |   pip install tabulate pandas | ||||||
|  | |||||||
							
								
								
									
										30
									
								
								.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,30 @@ | |||||||
|  | [ | ||||||
|  |     { | ||||||
|  |         "test_name": "latency_llama8B_tp1", | ||||||
|  |         "environment_variables": { | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 1, | ||||||
|  |             "load_format": "dummy", | ||||||
|  |             "num_iters_warmup": 5, | ||||||
|  |             "num_iters": 15 | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "test_name": "latency_llama8B_tp4", | ||||||
|  |         "environment_variables": { | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 4, | ||||||
|  |             "load_format": "dummy", | ||||||
|  |             "num_iters_warmup": 5, | ||||||
|  |             "num_iters": 15 | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | ] | ||||||
							
								
								
									
										158
									
								
								.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,158 @@ | |||||||
|  | [ | ||||||
|  |     { | ||||||
|  |         "test_name": "serving_llama8B_tp1_sharegpt", | ||||||
|  |         "qps_list": [1, 4, 16, "inf"], | ||||||
|  |         "server_environment_variables": { | ||||||
|  |             "VLLM_RPC_TIMEOUT": 100000, | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "server_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 1, | ||||||
|  | 	    "dtype": "bfloat16", | ||||||
|  | 	    "distributed_executor_backend": "mp", | ||||||
|  | 	    "block_size": 128, | ||||||
|  | 	    "trust_remote_code": "", | ||||||
|  |             "disable_log_stats": "", | ||||||
|  |             "disable_log_requests": "", | ||||||
|  | 	    "enforce_eager": "", | ||||||
|  |             "load_format": "dummy" | ||||||
|  |         }, | ||||||
|  |         "client_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "backend": "vllm", | ||||||
|  |             "dataset_name": "sharegpt", | ||||||
|  |             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||||
|  | 	    "max_concurrency": 60, | ||||||
|  |             "num_prompts": 200 | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "test_name": "serving_llama8B_tp2_sharegpt", | ||||||
|  |         "qps_list": [1, 4, 16, "inf"], | ||||||
|  |         "server_environment_variables": { | ||||||
|  |             "VLLM_RPC_TIMEOUT": 100000, | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "server_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 2, | ||||||
|  | 	    "dtype": "bfloat16", | ||||||
|  | 	    "distributed_executor_backend": "mp", | ||||||
|  | 	    "block_size": 128, | ||||||
|  | 	    "trust_remote_code": "", | ||||||
|  |             "disable_log_stats": "", | ||||||
|  |             "disable_log_requests": "", | ||||||
|  | 	    "enforce_eager": "", | ||||||
|  |             "load_format": "dummy" | ||||||
|  |         }, | ||||||
|  |         "client_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "backend": "vllm", | ||||||
|  |             "dataset_name": "sharegpt", | ||||||
|  |             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||||
|  | 	    "max_concurrency": 60, | ||||||
|  |             "num_prompts": 200 | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "test_name": "serving_llama8B_tp4_sharegpt", | ||||||
|  |         "qps_list": [1, 4, 16, "inf"], | ||||||
|  |         "server_environment_variables": { | ||||||
|  |             "VLLM_RPC_TIMEOUT": 100000, | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "server_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 4, | ||||||
|  | 	    "dtype": "bfloat16", | ||||||
|  | 	    "distributed_executor_backend": "mp", | ||||||
|  | 	    "block_size": 128, | ||||||
|  | 	    "trust_remote_code": "", | ||||||
|  |             "disable_log_stats": "", | ||||||
|  |             "disable_log_requests": "", | ||||||
|  | 	    "enforce_eager": "", | ||||||
|  |             "load_format": "dummy" | ||||||
|  |         }, | ||||||
|  |         "client_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "backend": "vllm", | ||||||
|  |             "dataset_name": "sharegpt", | ||||||
|  |             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||||
|  | 	    "max_concurrency": 60, | ||||||
|  |             "num_prompts": 200 | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "test_name": "serving_llama8B_tp4_random_1024_128", | ||||||
|  |         "qps_list": [1, 4, 16, "inf"], | ||||||
|  |         "server_environment_variables": { | ||||||
|  |             "VLLM_RPC_TIMEOUT": 100000, | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "server_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 4, | ||||||
|  | 	    "dtype": "bfloat16", | ||||||
|  | 	    "distributed_executor_backend": "mp", | ||||||
|  | 	    "block_size": 128, | ||||||
|  | 	    "trust_remote_code": "", | ||||||
|  | 	    "enable_chunked_prefill": "", | ||||||
|  |             "disable_log_stats": "", | ||||||
|  |             "disable_log_requests": "", | ||||||
|  | 	    "enforce_eager": "", | ||||||
|  |             "load_format": "dummy" | ||||||
|  |         }, | ||||||
|  |         "client_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "backend": "vllm", | ||||||
|  |             "dataset_name": "random", | ||||||
|  | 	    "random-input-len": 1024, | ||||||
|  | 	    "random-output-len": 128, | ||||||
|  | 	    "ignore-eos": "", | ||||||
|  | 	    "max_concurrency": 100, | ||||||
|  |             "num_prompts": 100 | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "test_name": "serving_llama8B_pp6_random_1024_128", | ||||||
|  |         "qps_list": [1, 4, 16, "inf"], | ||||||
|  |         "server_environment_variables": { | ||||||
|  |             "VLLM_RPC_TIMEOUT": 100000, | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "server_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "pipeline_parallel_size": 6, | ||||||
|  | 	    "dtype": "bfloat16", | ||||||
|  | 	    "distributed_executor_backend": "mp", | ||||||
|  | 	    "block_size": 128, | ||||||
|  | 	    "trust_remote_code": "", | ||||||
|  | 	    "enable_chunked_prefill": "", | ||||||
|  |             "disable_log_stats": "", | ||||||
|  |             "disable_log_requests": "", | ||||||
|  | 	    "enforce_eager": "", | ||||||
|  |             "load_format": "dummy" | ||||||
|  |         }, | ||||||
|  |         "client_parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "backend": "vllm", | ||||||
|  |             "dataset_name": "random", | ||||||
|  | 	    "random-input-len": 1024, | ||||||
|  | 	    "random-output-len": 128, | ||||||
|  | 	    "ignore-eos": "", | ||||||
|  | 	    "max_concurrency": 100, | ||||||
|  |             "num_prompts": 100 | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | ] | ||||||
| @ -0,0 +1,32 @@ | |||||||
|  | [ | ||||||
|  |     { | ||||||
|  |         "test_name": "throughput_llama8B_tp1", | ||||||
|  |         "environment_variables": { | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 1, | ||||||
|  |             "load_format": "dummy", | ||||||
|  |             "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||||
|  |             "num_prompts": 200, | ||||||
|  |             "backend": "vllm" | ||||||
|  |         } | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |         "test_name": "throughput_llama8B_tp4", | ||||||
|  |         "environment_variables": { | ||||||
|  | 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||||
|  | 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||||
|  |         }, | ||||||
|  |         "parameters": { | ||||||
|  |             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||||
|  |             "tensor_parallel_size": 4, | ||||||
|  |             "load_format": "dummy", | ||||||
|  |             "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||||
|  |             "num_prompts": 200, | ||||||
|  |             "backend": "vllm" | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | ] | ||||||
| @ -52,7 +52,7 @@ steps: | |||||||
|       queue: cpu_queue_postmerge |       queue: cpu_queue_postmerge | ||||||
|     commands: |     commands: | ||||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" |       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." |       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." | ||||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" |       - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" | ||||||
|  |  | ||||||
|   - label: "Annotate release workflow" |   - label: "Annotate release workflow" | ||||||
| @ -101,7 +101,8 @@ steps: | |||||||
|       queue: cpu_queue_postmerge |       queue: cpu_queue_postmerge | ||||||
|     commands: |     commands: | ||||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" |       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." |       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." | ||||||
|  |       - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" | ||||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" |       - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" | ||||||
|     env: |     env: | ||||||
|       DOCKER_BUILDKIT: "1" |       DOCKER_BUILDKIT: "1" | ||||||
| @ -117,6 +118,7 @@ steps: | |||||||
|     commands: |     commands: | ||||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" |       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." |       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." | ||||||
|  |       - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" | ||||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" |       - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" | ||||||
|     env: |     env: | ||||||
|       DOCKER_BUILDKIT: "1" |       DOCKER_BUILDKIT: "1" | ||||||
|  | |||||||
| @ -107,10 +107,9 @@ fi | |||||||
|  |  | ||||||
| if [[ $commands == *" kernels/attention"* ]]; then | if [[ $commands == *" kernels/attention"* ]]; then | ||||||
|   commands="${commands} \ |   commands="${commands} \ | ||||||
|   --ignore=kernels/attention/stest_attention_selector.py \ |   --ignore=kernels/attention/test_attention_selector.py \ | ||||||
|   --ignore=kernels/attention/test_blocksparse_attention.py \ |   --ignore=kernels/attention/test_blocksparse_attention.py \ | ||||||
|   --ignore=kernels/attention/test_encoder_decoder_attn.py \ |   --ignore=kernels/attention/test_encoder_decoder_attn.py \ | ||||||
|   --ignore=kernels/attention/test_attention_selector.py \ |  | ||||||
|   --ignore=kernels/attention/test_flash_attn.py \ |   --ignore=kernels/attention/test_flash_attn.py \ | ||||||
|   --ignore=kernels/attention/test_flashinfer.py \ |   --ignore=kernels/attention/test_flashinfer.py \ | ||||||
|   --ignore=kernels/attention/test_prefix_prefill.py \ |   --ignore=kernels/attention/test_prefix_prefill.py \ | ||||||
|  | |||||||
| @ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE | |||||||
| numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . | numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . | ||||||
|  |  | ||||||
| # Run the image, setting --shm-size=4g for tensor parallel. | # Run the image, setting --shm-size=4g for tensor parallel. | ||||||
| docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" | docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" | ||||||
| docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 | docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 | ||||||
|  |  | ||||||
| function cpu_tests() { | function cpu_tests() { | ||||||
|   set -e |   set -e | ||||||
|   export NUMA_NODE=$2 |   export NUMA_NODE=$2 | ||||||
|  |  | ||||||
|  |   # list packages | ||||||
|  |   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " | ||||||
|  |     set -e | ||||||
|  |     pip list" | ||||||
|  |  | ||||||
|  |   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||||
|  |     set -e | ||||||
|  |     pip list" | ||||||
|  |  | ||||||
|   # offline inference |   # offline inference | ||||||
|   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " |   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " | ||||||
|     set -e |     set -e | ||||||
| @ -39,9 +48,16 @@ function cpu_tests() { | |||||||
|   # Run basic model test |   # Run basic model test | ||||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " |   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||||
|     set -e |     set -e | ||||||
|     pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model |     # Note: disable until supports V1 | ||||||
|     pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model |     # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model | ||||||
|     pytest -v -s tests/models/language/generation -m cpu_model |     # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model | ||||||
|  |  | ||||||
|  |     # Note: disable Bart until supports V1 | ||||||
|  |     pytest -v -s tests/models/language/generation -m cpu_model \ | ||||||
|  |                 --ignore=tests/models/language/generation/test_bart.py | ||||||
|  |     VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ | ||||||
|  |                 --ignore=tests/models/language/generation/test_bart.py | ||||||
|  |  | ||||||
|     pytest -v -s tests/models/language/pooling -m cpu_model |     pytest -v -s tests/models/language/pooling -m cpu_model | ||||||
|     pytest -v -s tests/models/multimodal/generation \ |     pytest -v -s tests/models/multimodal/generation \ | ||||||
|                 --ignore=tests/models/multimodal/generation/test_mllama.py \ |                 --ignore=tests/models/multimodal/generation/test_mllama.py \ | ||||||
| @ -52,27 +68,21 @@ function cpu_tests() { | |||||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " |   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||||
|     set -e |     set -e | ||||||
|     pytest -s -v \ |     pytest -s -v \ | ||||||
|     tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ |     tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"  | ||||||
|     tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" |  | ||||||
|  |  | ||||||
|  |   # Note: disable it until supports V1 | ||||||
|   # Run AWQ test |   # Run AWQ test | ||||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " |   # docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||||
|     set -e |   #   set -e | ||||||
|     VLLM_USE_V1=0 pytest -s -v \ |   #   VLLM_USE_V1=0 pytest -s -v \ | ||||||
|     tests/quantization/test_ipex_quant.py" |   #   tests/quantization/test_ipex_quant.py" | ||||||
|  |  | ||||||
|   # Run chunked-prefill and prefix-cache test |  | ||||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " |  | ||||||
|     set -e |  | ||||||
|     pytest -s -v -k cpu_model \ |  | ||||||
|     tests/basic_correctness/test_chunked_prefill.py"   |  | ||||||
|  |  | ||||||
|   # online serving |   # online serving | ||||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " |   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||||
|     set -e |     set -e | ||||||
|     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &  |     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &  | ||||||
|     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 |     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 | ||||||
|     python3 benchmarks/benchmark_serving.py \ |     VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \ | ||||||
|       --backend vllm \ |       --backend vllm \ | ||||||
|       --dataset-name random \ |       --dataset-name random \ | ||||||
|       --model facebook/opt-125m \ |       --model facebook/opt-125m \ | ||||||
| @ -89,4 +99,4 @@ function cpu_tests() { | |||||||
|  |  | ||||||
| # All of CPU tests are expected to be finished less than 40 mins. | # All of CPU tests are expected to be finished less than 40 mins. | ||||||
| export -f cpu_tests | export -f cpu_tests | ||||||
| timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" | timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" | ||||||
|  | |||||||
| @ -2,10 +2,32 @@ | |||||||
|  |  | ||||||
| # This script build the CPU docker image and run the offline inference inside the container. | # This script build the CPU docker image and run the offline inference inside the container. | ||||||
| # It serves a sanity check for compilation and basic model usage. | # It serves a sanity check for compilation and basic model usage. | ||||||
| set -ex | set -exuo pipefail | ||||||
|  |  | ||||||
| # Try building the docker image | # Try building the docker image | ||||||
| docker build -t hpu-test-env -f docker/Dockerfile.hpu . | cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - . | ||||||
|  | FROM gaudi-base-image:latest | ||||||
|  |  | ||||||
|  | COPY ./ /workspace/vllm | ||||||
|  |  | ||||||
|  | WORKDIR /workspace/vllm | ||||||
|  |  | ||||||
|  | ENV no_proxy=localhost,127.0.0.1 | ||||||
|  | ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true | ||||||
|  |  | ||||||
|  | RUN VLLM_TARGET_DEVICE=empty pip install . | ||||||
|  | RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git | ||||||
|  |  | ||||||
|  | # install development dependencies (for testing) | ||||||
|  | RUN python3 -m pip install -e tests/vllm_test_utils | ||||||
|  |  | ||||||
|  | WORKDIR /workspace/ | ||||||
|  |  | ||||||
|  | RUN git clone https://github.com/vllm-project/vllm-gaudi.git | ||||||
|  |  | ||||||
|  | RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks | ||||||
|  |  | ||||||
|  | EOF | ||||||
|  |  | ||||||
| # Setup cleanup | # Setup cleanup | ||||||
| # certain versions of HPU software stack have a bug that can | # certain versions of HPU software stack have a bug that can | ||||||
| @ -14,13 +36,21 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu . | |||||||
| # functions, while other platforms only need one remove_docker_container | # functions, while other platforms only need one remove_docker_container | ||||||
| # function. | # function. | ||||||
| EXITCODE=1 | EXITCODE=1 | ||||||
| remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; } | remove_docker_containers() { docker rm -f hpu-plugin-v1-test || true; } | ||||||
| remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; } | trap 'remove_docker_containers; exit $EXITCODE;' EXIT | ||||||
| trap remove_docker_containers_and_exit EXIT |  | ||||||
| remove_docker_containers | remove_docker_containers | ||||||
|  |  | ||||||
| # Run the image and launch offline inference | echo "Running HPU plugin v1 test" | ||||||
| docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m | docker run --rm --runtime=habana --name=hpu-plugin-v1-test --network=host \ | ||||||
| docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2 |   -e HABANA_VISIBLE_DEVICES=all \ | ||||||
|  |   hpu-plugin-v1-test-env \ | ||||||
|  |   /bin/bash "/workspace/vllm-gaudi/tests/upstream_tests/ci_tests.sh" | ||||||
|  |  | ||||||
| EXITCODE=$? | EXITCODE=$? | ||||||
|  | if [ $EXITCODE -eq 0 ]; then | ||||||
|  |   echo "Test with basic model passed" | ||||||
|  | else | ||||||
|  |   echo "Test with basic model FAILED with exit code: $EXITCODE" >&2 | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | # The trap will handle the container removal and final exit. | ||||||
| @ -54,10 +54,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \ | |||||||
|        --name "${container_name}" \ |        --name "${container_name}" \ | ||||||
|        ${image_name} \ |        ${image_name} \ | ||||||
|        /bin/bash -c " |        /bin/bash -c " | ||||||
|  |             set -e; # Exit on first error | ||||||
|             python3 /workspace/vllm/examples/offline_inference/neuron.py; |             python3 /workspace/vllm/examples/offline_inference/neuron.py; | ||||||
|             python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; |             python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; | ||||||
|             for f in /workspace/vllm/tests/neuron/2_core/*.py; do |             for f in /workspace/vllm/tests/neuron/2_core/*.py; do | ||||||
|                 echo 'Running test file: '$f; |                 echo \"Running test file: \$f\"; | ||||||
|                 python3 -m pytest \$f -v --capture=tee-sys; |                 python3 -m pytest \$f -v --capture=tee-sys; | ||||||
|             done |             done | ||||||
|        " |        " | ||||||
| @ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \ | |||||||
|     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" |     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" | ||||||
| run_and_track_test 15 "test_spmd_model_weight_loading.py" \ | run_and_track_test 15 "test_spmd_model_weight_loading.py" \ | ||||||
|     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" |     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" | ||||||
|  | run_and_track_test 16 "test_kv_cache_update_kernel.py" \ | ||||||
|  |     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" | ||||||
|  |  | ||||||
| # After all tests have been attempted, exit with the overall status. | # After all tests have been attempted, exit with the overall status. | ||||||
| if [ "$overall_script_exit_code" -ne 0 ]; then | if [ "$overall_script_exit_code" -ne 0 ]; then | ||||||
|  | |||||||
| @ -11,8 +11,8 @@ container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head | |||||||
| docker build -t ${image_name} -f docker/Dockerfile.xpu . | docker build -t ${image_name} -f docker/Dockerfile.xpu . | ||||||
|  |  | ||||||
| # Setup cleanup | # Setup cleanup | ||||||
| remove_docker_container() {  | remove_docker_container() { | ||||||
|   docker rm -f "${container_name}" || true;  |   docker rm -f "${container_name}" || true; | ||||||
|   docker image rm -f "${image_name}" || true; |   docker image rm -f "${image_name}" || true; | ||||||
|   docker system prune -f || true; |   docker system prune -f || true; | ||||||
| } | } | ||||||
| @ -26,6 +26,9 @@ docker run \ | |||||||
|     --name "${container_name}" \ |     --name "${container_name}" \ | ||||||
|     "${image_name}" \ |     "${image_name}" \ | ||||||
|     sh -c ' |     sh -c ' | ||||||
|     VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m |     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager | ||||||
|     VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 |     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray | ||||||
|  |     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp | ||||||
|  |     cd tests | ||||||
|  |     pytest -v -s v1/core | ||||||
| ' | ' | ||||||
|  | |||||||
| @ -4,8 +4,8 @@ CONTAINER_NAME=vllm-tpu | |||||||
|  |  | ||||||
| # vllm config | # vllm config | ||||||
| MODEL=meta-llama/Llama-3.1-8B-Instruct | MODEL=meta-llama/Llama-3.1-8B-Instruct | ||||||
| MAX_NUM_SEQS=512 | MAX_NUM_SEQS=256 | ||||||
| MAX_NUM_BATCHED_TOKENS=512 | MAX_NUM_BATCHED_TOKENS=1024 | ||||||
| TENSOR_PARALLEL_SIZE=1 | TENSOR_PARALLEL_SIZE=1 | ||||||
| MAX_MODEL_LEN=2048 | MAX_MODEL_LEN=2048 | ||||||
| DOWNLOAD_DIR=/mnt/disks/persist | DOWNLOAD_DIR=/mnt/disks/persist | ||||||
|  | |||||||
| @ -22,16 +22,6 @@ trap remove_docker_container EXIT | |||||||
| # Remove the container that might not be cleaned up in the previous run. | # Remove the container that might not be cleaned up in the previous run. | ||||||
| remove_docker_container | remove_docker_container | ||||||
|  |  | ||||||
| # Build docker image. |  | ||||||
| # TODO: build the image outside the script and share the image with other |  | ||||||
| # tpu test if building time is too long. |  | ||||||
| DOCKER_BUILDKIT=1 docker build \ |  | ||||||
|   --build-arg max_jobs=16 \ |  | ||||||
|   --build-arg USE_SCCACHE=1 \ |  | ||||||
|   --build-arg GIT_REPO_CHECK=0 \ |  | ||||||
|   --tag vllm/vllm-tpu-bm \ |  | ||||||
|   --progress plain -f docker/Dockerfile.tpu . |  | ||||||
|  |  | ||||||
| LOG_ROOT=$(mktemp -d) | LOG_ROOT=$(mktemp -d) | ||||||
| # If mktemp fails, set -e will cause the script to exit. | # If mktemp fails, set -e will cause the script to exit. | ||||||
| echo "Results will be stored in: $LOG_ROOT" | echo "Results will be stored in: $LOG_ROOT" | ||||||
| @ -68,7 +58,7 @@ docker run \ | |||||||
|  |  | ||||||
| echo "run script..." | echo "run script..." | ||||||
| echo | echo | ||||||
| docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh" | docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh" | ||||||
|  |  | ||||||
| echo "copy result back..." | echo "copy result back..." | ||||||
| VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt | VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt | ||||||
|  | |||||||
							
								
								
									
										14
									
								
								.buildkite/scripts/tpu/quantized_v6e_1.env
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								.buildkite/scripts/tpu/quantized_v6e_1.env
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | |||||||
|  | # Environment config | ||||||
|  | TEST_NAME=llama8bw8a8 | ||||||
|  | CONTAINER_NAME=vllm-tpu | ||||||
|  |  | ||||||
|  | # vllm config | ||||||
|  | MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | ||||||
|  | MAX_NUM_SEQS=128 | ||||||
|  | MAX_NUM_BATCHED_TOKENS=1024 | ||||||
|  | TENSOR_PARALLEL_SIZE=1 | ||||||
|  | MAX_MODEL_LEN=2048 | ||||||
|  | DOWNLOAD_DIR=/mnt/disks/persist | ||||||
|  | EXPECTED_THROUGHPUT=10.0 | ||||||
|  | INPUT_LEN=1800 | ||||||
|  | OUTPUT_LEN=128 | ||||||
| @ -41,6 +41,16 @@ steps: | |||||||
|   # TODO: add `--strict` once warnings in docstrings are fixed |   # TODO: add `--strict` once warnings in docstrings are fixed | ||||||
|   - mkdocs build |   - mkdocs build | ||||||
|  |  | ||||||
|  | - label: Pytorch Nightly Dependency Override Check # 2min | ||||||
|  |   # if this test fails, it means the nightly torch version is not compatible with some | ||||||
|  |   # of the dependencies. Please check the error message and add the package to whitelist | ||||||
|  |   # in /vllm/tools/generate_nightly_torch_test.py | ||||||
|  |   soft_fail: true | ||||||
|  |   source_file_dependencies: | ||||||
|  |   - requirements/nightly_torch_test.txt | ||||||
|  |   commands: | ||||||
|  |   - bash standalone_tests/pytorch_nightly_dependency.sh | ||||||
|  |  | ||||||
| - label: Async Engine, Inputs, Utils, Worker Test # 24min | - label: Async Engine, Inputs, Utils, Worker Test # 24min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
| @ -89,7 +99,7 @@ steps: | |||||||
|   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py |   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py | ||||||
|  |  | ||||||
| - label: Chunked Prefill Test | - label: Chunked Prefill Test | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental, amdproduction] | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/ |   - vllm/ | ||||||
|   - tests/basic_correctness/test_chunked_prefill |   - tests/basic_correctness/test_chunked_prefill | ||||||
| @ -107,7 +117,7 @@ steps: | |||||||
|   commands: |   commands: | ||||||
|   - pytest -v -s core |   - pytest -v -s core | ||||||
|  |  | ||||||
| - label: Entrypoints Test # 40min | - label: Entrypoints Test (LLM) # 40min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
|   working_dir: "/vllm-workspace/tests" |   working_dir: "/vllm-workspace/tests" | ||||||
|   fast_check: true |   fast_check: true | ||||||
| @ -115,8 +125,6 @@ steps: | |||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/ |   - vllm/ | ||||||
|   - tests/entrypoints/llm |   - tests/entrypoints/llm | ||||||
|   - tests/entrypoints/openai |  | ||||||
|   - tests/entrypoints/test_chat_utils |  | ||||||
|   - tests/entrypoints/offline_mode |   - tests/entrypoints/offline_mode | ||||||
|   commands: |   commands: | ||||||
|   - export VLLM_WORKER_MULTIPROC_METHOD=spawn |   - export VLLM_WORKER_MULTIPROC_METHOD=spawn | ||||||
| @ -125,9 +133,21 @@ steps: | |||||||
|   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process |   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process | ||||||
|   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process |   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process | ||||||
|   - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process |   - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process | ||||||
|  |   - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests | ||||||
|  |  | ||||||
|  | - label: Entrypoints Test (API Server) # 40min | ||||||
|  |   mirror_hardwares: [amdexperimental] | ||||||
|  |   working_dir: "/vllm-workspace/tests" | ||||||
|  |   fast_check: true | ||||||
|  |   torch_nightly: true | ||||||
|  |   source_file_dependencies: | ||||||
|  |   - vllm/ | ||||||
|  |   - tests/entrypoints/openai | ||||||
|  |   - tests/entrypoints/test_chat_utils | ||||||
|  |   commands: | ||||||
|  |   - export VLLM_WORKER_MULTIPROC_METHOD=spawn | ||||||
|   - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ |   - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ | ||||||
|   - pytest -v -s entrypoints/test_chat_utils.py |   - pytest -v -s entrypoints/test_chat_utils.py | ||||||
|   - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests |  | ||||||
|  |  | ||||||
| - label: Distributed Tests (4 GPUs) # 10min | - label: Distributed Tests (4 GPUs) # 10min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
| @ -145,6 +165,7 @@ steps: | |||||||
|   - examples/offline_inference/rlhf_colocate.py |   - examples/offline_inference/rlhf_colocate.py | ||||||
|   - tests/examples/offline_inference/data_parallel.py |   - tests/examples/offline_inference/data_parallel.py | ||||||
|   - tests/v1/test_async_llm_dp.py |   - tests/v1/test_async_llm_dp.py | ||||||
|  |   - tests/v1/test_external_lb_dp.py | ||||||
|   - tests/v1/engine/test_engine_core_client.py |   - tests/v1/engine/test_engine_core_client.py | ||||||
|   commands: |   commands: | ||||||
|   # test with tp=2 and external_dp=2 |   # test with tp=2 and external_dp=2 | ||||||
| @ -153,8 +174,9 @@ steps: | |||||||
|   # test with tp=2 and pp=2 |   # test with tp=2 and pp=2 | ||||||
|   - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py |   - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py | ||||||
|   # test with internal dp |   # test with internal dp | ||||||
|   - python3 ../examples/offline_inference/data_parallel.py |   - python3 ../examples/offline_inference/data_parallel.py --enforce-eager | ||||||
|   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py |   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py | ||||||
|  |   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py | ||||||
|   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp |   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp | ||||||
|   - pytest -v -s distributed/test_utils.py |   - pytest -v -s distributed/test_utils.py | ||||||
|   - pytest -v -s compile/test_basic_correctness.py |   - pytest -v -s compile/test_basic_correctness.py | ||||||
| @ -168,6 +190,23 @@ steps: | |||||||
|   - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py |   - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py | ||||||
|   - popd |   - popd | ||||||
|  |  | ||||||
|  | - label: EPLB Algorithm Test | ||||||
|  |   working_dir: "/vllm-workspace/tests" | ||||||
|  |   source_file_dependencies: | ||||||
|  |   - vllm/distributed/eplb | ||||||
|  |   - tests/distributed/test_eplb_algo.py | ||||||
|  |   commands: | ||||||
|  |   - pytest -v -s distributed/test_eplb_algo.py | ||||||
|  |  | ||||||
|  | - label: EPLB Execution Test # 5min | ||||||
|  |   working_dir: "/vllm-workspace/tests" | ||||||
|  |   num_gpus: 4 | ||||||
|  |   source_file_dependencies: | ||||||
|  |   - vllm/distributed/eplb | ||||||
|  |   - tests/distributed/test_eplb_execute.py | ||||||
|  |   commands: | ||||||
|  |   - pytest -v -s distributed/test_eplb_execute.py | ||||||
|  |  | ||||||
| - label: Metrics, Tracing Test # 10min | - label: Metrics, Tracing Test # 10min | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental, amdproduction] | ||||||
|   num_gpus: 2 |   num_gpus: 2 | ||||||
| @ -177,13 +216,18 @@ steps: | |||||||
|   - tests/tracing |   - tests/tracing | ||||||
|   commands: |   commands: | ||||||
|   - pytest -v -s metrics |   - pytest -v -s metrics | ||||||
|  |   - "pip install \ | ||||||
|  |       'opentelemetry-sdk>=1.26.0' \ | ||||||
|  |       'opentelemetry-api>=1.26.0' \ | ||||||
|  |       'opentelemetry-exporter-otlp>=1.26.0' \ | ||||||
|  |       'opentelemetry-semantic-conventions-ai>=0.4.1'" | ||||||
|   - pytest -v -s tracing |   - pytest -v -s tracing | ||||||
|  |  | ||||||
| ##### fast check tests  ##### | ##### fast check tests  ##### | ||||||
| #####  1 GPU test  ##### | #####  1 GPU test  ##### | ||||||
|  |  | ||||||
| - label: Regression Test # 5min | - label: Regression Test # 5min | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental] | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/ |   - vllm/ | ||||||
|   - tests/test_regression |   - tests/test_regression | ||||||
| @ -193,7 +237,7 @@ steps: | |||||||
|   working_dir: "/vllm-workspace/tests" # optional |   working_dir: "/vllm-workspace/tests" # optional | ||||||
|  |  | ||||||
| - label: Engine Test # 10min | - label: Engine Test # 10min | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental] | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/ |   - vllm/ | ||||||
|   - tests/engine |   - tests/engine | ||||||
| @ -248,7 +292,7 @@ steps: | |||||||
|     - python3 offline_inference/llm_engine_example.py |     - python3 offline_inference/llm_engine_example.py | ||||||
|     - python3 offline_inference/audio_language.py --seed 0 |     - python3 offline_inference/audio_language.py --seed 0 | ||||||
|     - python3 offline_inference/vision_language.py --seed 0 |     - python3 offline_inference/vision_language.py --seed 0 | ||||||
|     - python3 offline_inference/vision_language_embedding.py --seed 0 |     - python3 offline_inference/vision_language_pooling.py --seed 0 | ||||||
|     - python3 offline_inference/vision_language_multi_image.py --seed 0 |     - python3 offline_inference/vision_language_multi_image.py --seed 0 | ||||||
|     - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors |     - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors | ||||||
|     - python3 offline_inference/encoder_decoder.py |     - python3 offline_inference/encoder_decoder.py | ||||||
| @ -266,6 +310,15 @@ steps: | |||||||
|   commands: |   commands: | ||||||
|     - pytest -v -s prefix_caching |     - pytest -v -s prefix_caching | ||||||
|  |  | ||||||
|  |  | ||||||
|  | - label: Platform Tests (CUDA) | ||||||
|  |   mirror_hardwares: [amdexperimental] | ||||||
|  |   source_file_dependencies: | ||||||
|  |   - vllm/ | ||||||
|  |   - tests/cuda | ||||||
|  |   commands: | ||||||
|  |     - pytest -v -s cuda/test_cuda_context.py | ||||||
|  |  | ||||||
| - label: Samplers Test # 36min | - label: Samplers Test # 36min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
| @ -297,7 +350,7 @@ steps: | |||||||
|   parallelism: 4 |   parallelism: 4 | ||||||
|  |  | ||||||
| - label: PyTorch Compilation Unit Tests | - label: PyTorch Compilation Unit Tests | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental] | ||||||
|   torch_nightly: true |   torch_nightly: true | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|     - vllm/ |     - vllm/ | ||||||
| @ -305,6 +358,7 @@ steps: | |||||||
|   commands: |   commands: | ||||||
|     - pytest -v -s compile/test_pass_manager.py |     - pytest -v -s compile/test_pass_manager.py | ||||||
|     - pytest -v -s compile/test_fusion.py |     - pytest -v -s compile/test_fusion.py | ||||||
|  |     - pytest -v -s compile/test_fusion_attn.py | ||||||
|     - pytest -v -s compile/test_silu_mul_quant_fusion.py |     - pytest -v -s compile/test_silu_mul_quant_fusion.py | ||||||
|     - pytest -v -s compile/test_sequence_parallelism.py |     - pytest -v -s compile/test_sequence_parallelism.py | ||||||
|     - pytest -v -s compile/test_async_tp.py |     - pytest -v -s compile/test_async_tp.py | ||||||
| @ -378,7 +432,7 @@ steps: | |||||||
|     - pytest -v -s kernels/mamba |     - pytest -v -s kernels/mamba | ||||||
|  |  | ||||||
| - label: Tensorizer Test # 11min | - label: Tensorizer Test # 11min | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental] | ||||||
|   soft_fail: true |   soft_fail: true | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/model_executor/model_loader |   - vllm/model_executor/model_loader | ||||||
| @ -470,7 +524,7 @@ steps: | |||||||
| #####  models test  ##### | #####  models test  ##### | ||||||
|  |  | ||||||
| - label: Basic Models Test # 24min | - label: Basic Models Test # 24min | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental] | ||||||
|   torch_nightly: true |   torch_nightly: true | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/ |   - vllm/ | ||||||
| @ -494,6 +548,17 @@ steps: | |||||||
|     - pip freeze | grep -E 'torch' |     - pip freeze | grep -E 'torch' | ||||||
|     - pytest -v -s models/language -m core_model |     - pytest -v -s models/language -m core_model | ||||||
|  |  | ||||||
|  | - label: Language Models Test (Hybrid) # 35 min | ||||||
|  |   mirror_hardwares: [amdexperimental] | ||||||
|  |   torch_nightly: true | ||||||
|  |   source_file_dependencies: | ||||||
|  |   - vllm/ | ||||||
|  |   - tests/models/language/generation | ||||||
|  |   commands: | ||||||
|  |     # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. | ||||||
|  |     - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' | ||||||
|  |     - pytest -v -s models/language/generation -m hybrid_model | ||||||
|  |  | ||||||
| - label: Language Models Test (Extended Generation) # 1hr20min | - label: Language Models Test (Extended Generation) # 1hr20min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
|   optional: true |   optional: true | ||||||
| @ -503,7 +568,7 @@ steps: | |||||||
|   commands: |   commands: | ||||||
|     # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. |     # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. | ||||||
|     - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' |     - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' | ||||||
|     - pytest -v -s models/language/generation -m 'not core_model' |     - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' | ||||||
|  |  | ||||||
| - label: Language Models Test (Extended Pooling)  # 36min | - label: Language Models Test (Extended Pooling)  # 36min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
| @ -548,7 +613,7 @@ steps: | |||||||
|     - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' |     - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' | ||||||
|  |  | ||||||
| - label: Multi-Modal Models Test (Extended) 3 | - label: Multi-Modal Models Test (Extended) 3 | ||||||
|   mirror_hardwares: [amdexperimental, amdproduction] |   mirror_hardwares: [amdexperimental] | ||||||
|   optional: true |   optional: true | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
|   - vllm/ |   - vllm/ | ||||||
| @ -575,6 +640,18 @@ steps: | |||||||
|     # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py |     # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py | ||||||
|     # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* |     # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* | ||||||
|  |  | ||||||
|  | - label: Transformers Nightly Models Test | ||||||
|  |   working_dir: "/vllm-workspace/" | ||||||
|  |   optional: true | ||||||
|  |   commands: | ||||||
|  |     - pip install --upgrade git+https://github.com/huggingface/transformers | ||||||
|  |     - pytest -v -s tests/models/test_initialization.py | ||||||
|  |     - pytest -v -s tests/models/multimodal/processing/ | ||||||
|  |     - pytest -v -s tests/models/multimodal/test_mapping.py | ||||||
|  |     - python3 examples/offline_inference/basic/chat.py | ||||||
|  |     - python3 examples/offline_inference/audio_language.py --model-type whisper | ||||||
|  |     - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl | ||||||
|  |  | ||||||
| #####  1 GPU test  ##### | #####  1 GPU test  ##### | ||||||
| #####  multi gpus test  ##### | #####  multi gpus test  ##### | ||||||
|  |  | ||||||
| @ -600,13 +677,18 @@ steps: | |||||||
|   - vllm/executor/ |   - vllm/executor/ | ||||||
|   - vllm/model_executor/models/ |   - vllm/model_executor/models/ | ||||||
|   - tests/distributed/ |   - tests/distributed/ | ||||||
|  |   - tests/examples/offline_inference/data_parallel.py | ||||||
|   commands: |   commands: | ||||||
|   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) |   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) | ||||||
|     - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' |     - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' | ||||||
|  |     - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' | ||||||
|  |     - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code | ||||||
|     - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py |     - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py | ||||||
|     - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py |     - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py | ||||||
|   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) |   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) | ||||||
|     - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' |     - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' | ||||||
|  |     - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' | ||||||
|  |     - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code | ||||||
|  |  | ||||||
| - label: Distributed Tests (2 GPUs) # 40min | - label: Distributed Tests (2 GPUs) # 40min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental] | ||||||
| @ -624,10 +706,12 @@ steps: | |||||||
|   - vllm/worker/model_runner.py |   - vllm/worker/model_runner.py | ||||||
|   - entrypoints/llm/test_collective_rpc.py |   - entrypoints/llm/test_collective_rpc.py | ||||||
|   - tests/v1/test_async_llm_dp.py |   - tests/v1/test_async_llm_dp.py | ||||||
|  |   - tests/v1/test_external_lb_dp.py | ||||||
|   - tests/v1/entrypoints/openai/test_multi_api_servers.py |   - tests/v1/entrypoints/openai/test_multi_api_servers.py | ||||||
|   - vllm/v1/engine/ |   - vllm/v1/engine/ | ||||||
|   commands: |   commands: | ||||||
|   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py |   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py | ||||||
|  |   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py | ||||||
|   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py |   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py | ||||||
|   - pytest -v -s entrypoints/llm/test_collective_rpc.py |   - pytest -v -s entrypoints/llm/test_collective_rpc.py | ||||||
|   - pytest -v -s ./compile/test_basic_correctness.py |   - pytest -v -s ./compile/test_basic_correctness.py | ||||||
| @ -669,7 +753,7 @@ steps: | |||||||
|   - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins |   - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins | ||||||
|  |  | ||||||
| - label: Multi-step Tests (4 GPUs) # 36min | - label: Multi-step Tests (4 GPUs) # 36min | ||||||
|   mirror_hardwares: [amdexperimental] |   mirror_hardwares: [amdexperimental, amdproduction] | ||||||
|   working_dir: "/vllm-workspace/tests" |   working_dir: "/vllm-workspace/tests" | ||||||
|   num_gpus: 4 |   num_gpus: 4 | ||||||
|   source_file_dependencies: |   source_file_dependencies: | ||||||
| @ -730,7 +814,7 @@ steps: | |||||||
|     - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt |     - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt | ||||||
|  |  | ||||||
| - label: Weight Loading Multiple GPU Test - Large Models # optional | - label: Weight Loading Multiple GPU Test - Large Models # optional | ||||||
|   mirror_hardwares: [amdexperimental]  |   mirror_hardwares: [amdexperimental] | ||||||
|   working_dir: "/vllm-workspace/tests" |   working_dir: "/vllm-workspace/tests" | ||||||
|   num_gpus: 2 |   num_gpus: 2 | ||||||
|   gpu: a100 |   gpu: a100 | ||||||
|  | |||||||
							
								
								
									
										6
									
								
								.gemini/config.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.gemini/config.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | |||||||
|  | # https://developers.google.com/gemini-code-assist/docs/customize-gemini-behavior-github | ||||||
|  | have_fun: false  # Just review the code | ||||||
|  | code_review: | ||||||
|  |   comment_severity_threshold: HIGH  # Reduce quantity of comments | ||||||
|  |   pull_request_opened: | ||||||
|  |     summary: false  # Don't summarize the PR in a separate comment | ||||||
							
								
								
									
										7
									
								
								.github/CODEOWNERS
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										7
									
								
								.github/CODEOWNERS
									
									
									
									
										vendored
									
									
								
							| @ -16,7 +16,12 @@ | |||||||
| /vllm/lora @jeejeelee | /vllm/lora @jeejeelee | ||||||
| /vllm/reasoning @aarnphm | /vllm/reasoning @aarnphm | ||||||
| /vllm/entrypoints @aarnphm | /vllm/entrypoints @aarnphm | ||||||
| CMakeLists.txt @tlrmchlsmth | /vllm/compilation @zou3519 @youkaichao | ||||||
|  | CMakeLists.txt @tlrmchlsmth @LucasWilkinson | ||||||
|  |  | ||||||
|  | # Any change to VllmConfig can have a large user-facing impact, | ||||||
|  | # so spam a lot of people | ||||||
|  | /vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor | ||||||
|  |  | ||||||
| # vLLM V1 | # vLLM V1 | ||||||
| /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat | /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat | ||||||
|  | |||||||
							
								
								
									
										83
									
								
								.github/mergify.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										83
									
								
								.github/mergify.yml
									
									
									
									
										vendored
									
									
								
							| @ -27,6 +27,22 @@ pull_request_rules: | |||||||
|       add: |       add: | ||||||
|         - ci/build |         - ci/build | ||||||
|  |  | ||||||
|  | - name: label-deepseek | ||||||
|  |   description: Automatically apply deepseek label | ||||||
|  |   conditions: | ||||||
|  |     - or: | ||||||
|  |       - files~=^examples/.*deepseek.*\.py | ||||||
|  |       - files~=^tests/.*deepseek.*\.py | ||||||
|  |       - files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py | ||||||
|  |       - files~=^vllm/model_executor/models/.*deepseek.*\.py | ||||||
|  |       - files~=^vllm/reasoning/.*deepseek.*\.py | ||||||
|  |       - files~=^vllm/transformers_utils/.*deepseek.*\.py | ||||||
|  |       - title~=(?i)DeepSeek | ||||||
|  |   actions: | ||||||
|  |     label: | ||||||
|  |       add: | ||||||
|  |         - deepseek | ||||||
|  |  | ||||||
| - name: label-frontend | - name: label-frontend | ||||||
|   description: Automatically apply frontend label |   description: Automatically apply frontend label | ||||||
|   conditions: |   conditions: | ||||||
| @ -45,6 +61,7 @@ pull_request_rules: | |||||||
|       - files~=^vllm/entrypoints/openai/tool_parsers/llama.*\.py |       - files~=^vllm/entrypoints/openai/tool_parsers/llama.*\.py | ||||||
|       - files~=^vllm/model_executor/models/.*llama.*\.py |       - files~=^vllm/model_executor/models/.*llama.*\.py | ||||||
|       - files~=^vllm/transformers_utils/configs/.*llama.*\.py |       - files~=^vllm/transformers_utils/configs/.*llama.*\.py | ||||||
|  |       - title~=(?i)llama | ||||||
|   actions: |   actions: | ||||||
|     label: |     label: | ||||||
|       add: |       add: | ||||||
| @ -57,14 +74,70 @@ pull_request_rules: | |||||||
|       - files~=^vllm/multimodal/ |       - files~=^vllm/multimodal/ | ||||||
|       - files~=^tests/multimodal/ |       - files~=^tests/multimodal/ | ||||||
|       - files~=^tests/models/multimodal/ |       - files~=^tests/models/multimodal/ | ||||||
|       - files~=^tests/models/*/audio_language/ |  | ||||||
|       - files~=^tests/models/*/vision_language/ |  | ||||||
|       - files=tests/models/test_vision.py |       - files=tests/models/test_vision.py | ||||||
|   actions: |   actions: | ||||||
|     label: |     label: | ||||||
|       add: |       add: | ||||||
|         - multi-modality |         - multi-modality | ||||||
|  |  | ||||||
|  | - name: label-new-model | ||||||
|  |   description: Automatically apply new-model label | ||||||
|  |   conditions: | ||||||
|  |     - and: | ||||||
|  |       - files~=^vllm/model_executor/models/ | ||||||
|  |       - files=vllm/model_executor/models/registry.py | ||||||
|  |   actions: | ||||||
|  |     label: | ||||||
|  |       add: | ||||||
|  |         - new-model | ||||||
|  |  | ||||||
|  | - name: label-performance | ||||||
|  |   description: Automatically apply performance label | ||||||
|  |   conditions: | ||||||
|  |     - or: | ||||||
|  |       - files~=^benchmarks/ | ||||||
|  |       - files~=^vllm/benchmarks/ | ||||||
|  |       - files~=^tests/benchmarks/ | ||||||
|  |       - files~=^\.buildkite/nightly-benchmarks/ | ||||||
|  |   actions: | ||||||
|  |     label: | ||||||
|  |       add: | ||||||
|  |         - performance | ||||||
|  |  | ||||||
|  | - name: label-qwen | ||||||
|  |   description: Automatically apply qwen label | ||||||
|  |   conditions: | ||||||
|  |     - or: | ||||||
|  |       - files~=^examples/.*qwen.*\.py | ||||||
|  |       - files~=^tests/.*qwen.*\.py | ||||||
|  |       - files~=^vllm/model_executor/models/.*qwen.*\.py | ||||||
|  |       - files~=^vllm/reasoning/.*qwen.*\.py | ||||||
|  |       - title~=(?i)Qwen | ||||||
|  |   actions: | ||||||
|  |     label: | ||||||
|  |       add: | ||||||
|  |         - qwen | ||||||
|  |  | ||||||
|  | - name: label-rocm | ||||||
|  |   description: Automatically apply rocm label | ||||||
|  |   conditions: | ||||||
|  |     - or: | ||||||
|  |       - files~=^csrc/rocm/ | ||||||
|  |       - files~=^docker/Dockerfile.rocm | ||||||
|  |       - files~=^requirements/rocm.*\.txt | ||||||
|  |       - files~=^vllm/attention/backends/rocm.*\.py | ||||||
|  |       - files~=^vllm/attention/ops/rocm.*\.py | ||||||
|  |       - files~=^vllm/model_executor/layers/fused_moe/rocm.*\.py | ||||||
|  |       - files~=^vllm/v1/attention/backends/mla/rocm.*\.py | ||||||
|  |       - files~=^tests/kernels/.*_rocm.*\.py | ||||||
|  |       - files=vllm/platforms/rocm.py | ||||||
|  |       - title~=(?i)AMD | ||||||
|  |       - title~=(?i)ROCm | ||||||
|  |   actions: | ||||||
|  |     label: | ||||||
|  |       add: | ||||||
|  |         - rocm | ||||||
|  |  | ||||||
| - name: label-structured-output | - name: label-structured-output | ||||||
|   description: Automatically apply structured-output label |   description: Automatically apply structured-output label | ||||||
|   conditions: |   conditions: | ||||||
| @ -92,8 +165,14 @@ pull_request_rules: | |||||||
|   conditions: |   conditions: | ||||||
|     - or: |     - or: | ||||||
|       - files~=^vllm/spec_decode/ |       - files~=^vllm/spec_decode/ | ||||||
|  |       - files~=^vllm/v1/spec_decode/ | ||||||
|       - files=vllm/model_executor/layers/spec_decode_base_sampler.py |       - files=vllm/model_executor/layers/spec_decode_base_sampler.py | ||||||
|       - files~=^tests/spec_decode/ |       - files~=^tests/spec_decode/ | ||||||
|  |       - files~=^tests/v1/spec_decode/ | ||||||
|  |       - files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py | ||||||
|  |       - files~=^vllm/model_executor/models/.*eagle.*\.py | ||||||
|  |       - files=vllm/model_executor/models/mlp_speculator.py | ||||||
|  |       - files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py | ||||||
|   actions: |   actions: | ||||||
|     label: |     label: | ||||||
|       add: |       add: | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.github/workflows/lint-and-deploy.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/lint-and-deploy.yaml
									
									
									
									
										vendored
									
									
								
							| @ -68,7 +68,7 @@ jobs: | |||||||
|           export AWS_ACCESS_KEY_ID=minioadmin |           export AWS_ACCESS_KEY_ID=minioadmin | ||||||
|           export AWS_SECRET_ACCESS_KEY=minioadmin |           export AWS_SECRET_ACCESS_KEY=minioadmin | ||||||
|           sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & |           sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & | ||||||
|           helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" |           helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" | ||||||
|  |  | ||||||
|       - name: curl test |       - name: curl test | ||||||
|         run: | |         run: | | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -146,6 +146,7 @@ venv.bak/ | |||||||
|  |  | ||||||
| # mkdocs documentation | # mkdocs documentation | ||||||
| /site | /site | ||||||
|  | docs/argparse | ||||||
| docs/examples | docs/examples | ||||||
|  |  | ||||||
| # mypy | # mypy | ||||||
| @ -200,5 +201,5 @@ benchmarks/**/*.json | |||||||
| actionlint | actionlint | ||||||
| shellcheck*/ | shellcheck*/ | ||||||
|  |  | ||||||
| # Ingore moe/marlin_moe gen code | # Ignore moe/marlin_moe gen code | ||||||
| csrc/moe/marlin_moe_wna16/kernel_* | csrc/moe/marlin_moe_wna16/kernel_* | ||||||
|  | |||||||
| @ -20,12 +20,10 @@ repos: | |||||||
|     args: [--output-format, github, --fix] |     args: [--output-format, github, --fix] | ||||||
|   - id: ruff-format |   - id: ruff-format | ||||||
|     files: ^(.buildkite|benchmarks|examples)/.* |     files: ^(.buildkite|benchmarks|examples)/.* | ||||||
| - repo: https://github.com/codespell-project/codespell | - repo: https://github.com/crate-ci/typos | ||||||
|   rev: v2.4.1 |   rev: v1.34.0 | ||||||
|   hooks: |   hooks: | ||||||
|   - id: codespell |   - id: typos | ||||||
|     additional_dependencies: ['tomli'] |  | ||||||
|     args: ['--toml', 'pyproject.toml'] |  | ||||||
| - repo: https://github.com/PyCQA/isort | - repo: https://github.com/PyCQA/isort | ||||||
|   rev: 6.0.1 |   rev: 6.0.1 | ||||||
|   hooks: |   hooks: | ||||||
| @ -55,6 +53,11 @@ repos: | |||||||
|       files: ^requirements/test\.(in|txt)$ |       files: ^requirements/test\.(in|txt)$ | ||||||
| - repo: local | - repo: local | ||||||
|   hooks: |   hooks: | ||||||
|  |   - id: format-torch-nightly-test | ||||||
|  |     name: reformat nightly_torch_test.txt to be in sync with test.in | ||||||
|  |     language: python | ||||||
|  |     entry: python tools/generate_nightly_torch_test.py | ||||||
|  |     files: ^requirements/test\.(in|txt)$ | ||||||
|   - id: mypy-local |   - id: mypy-local | ||||||
|     name: Run mypy for local Python installation |     name: Run mypy for local Python installation | ||||||
|     entry: tools/mypy.sh 0 "local" |     entry: tools/mypy.sh 0 "local" | ||||||
| @ -117,6 +120,11 @@ repos: | |||||||
|     entry: python tools/check_spdx_header.py |     entry: python tools/check_spdx_header.py | ||||||
|     language: python |     language: python | ||||||
|     types: [python] |     types: [python] | ||||||
|  |   - id: check-root-lazy-imports | ||||||
|  |     name: Check root lazy imports | ||||||
|  |     entry: python tools/check_init_lazy_imports.py | ||||||
|  |     language: python | ||||||
|  |     types: [python] | ||||||
|   - id: check-filenames |   - id: check-filenames | ||||||
|     name: Check for spaces in all filenames |     name: Check for spaces in all filenames | ||||||
|     entry: bash |     entry: bash | ||||||
| @ -145,10 +153,24 @@ repos: | |||||||
|     types: [python] |     types: [python] | ||||||
|     pass_filenames: false |     pass_filenames: false | ||||||
|     additional_dependencies: [regex] |     additional_dependencies: [regex] | ||||||
|  |   - id: check-pickle-imports | ||||||
|  |     name: Prevent new pickle/cloudpickle imports | ||||||
|  |     entry: python tools/check_pickle_imports.py | ||||||
|  |     language: python | ||||||
|  |     types: [python] | ||||||
|  |     pass_filenames: false | ||||||
|  |     additional_dependencies: [pathspec, regex] | ||||||
|  |   - id: validate-config | ||||||
|  |     name: Validate configuration has default values and that each field has a docstring | ||||||
|  |     entry: python tools/validate_config.py | ||||||
|  |     language: python | ||||||
|  |     types: [python] | ||||||
|  |     pass_filenames: true | ||||||
|  |     files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py | ||||||
|   # Keep `suggestion` last |   # Keep `suggestion` last | ||||||
|   - id: suggestion |   - id: suggestion | ||||||
|     name: Suggestion |     name: Suggestion | ||||||
|     entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."' |     entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=<hook-id>."' | ||||||
|     language: system |     language: system | ||||||
|     verbose: true |     verbose: true | ||||||
|     pass_filenames: false |     pass_filenames: false | ||||||
|  | |||||||
							
								
								
									
										117
									
								
								CMakeLists.txt
									
									
									
									
									
								
							
							
						
						
									
										117
									
								
								CMakeLists.txt
									
									
									
									
									
								
							| @ -171,7 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") |   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
|  |  | ||||||
| # | # | ||||||
| # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. | # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. | ||||||
| # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. | # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. | ||||||
| @ -232,7 +231,6 @@ endif() | |||||||
|  |  | ||||||
| set(VLLM_EXT_SRC | set(VLLM_EXT_SRC | ||||||
|   "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" |   "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" | ||||||
|   "csrc/mamba/causal_conv1d/causal_conv1d.cu" |  | ||||||
|   "csrc/cache_kernels.cu" |   "csrc/cache_kernels.cu" | ||||||
|   "csrc/attention/paged_attention_v1.cu" |   "csrc/attention/paged_attention_v1.cu" | ||||||
|   "csrc/attention/paged_attention_v2.cu" |   "csrc/attention/paged_attention_v2.cu" | ||||||
| @ -259,7 +257,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") |   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") | ||||||
|  |  | ||||||
|   # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. |   # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. | ||||||
|   set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use") |   set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use") | ||||||
|  |  | ||||||
|   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided |   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided | ||||||
|   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) |   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) | ||||||
| @ -393,7 +391,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|   # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require |   # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require | ||||||
|   # CUDA 12.0 or later |   # CUDA 12.0 or later | ||||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) | ||||||
|     set(SRCS |     set(SRCS | ||||||
|        "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu" |        "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu" | ||||||
|        "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" |        "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" | ||||||
| @ -409,7 +407,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") |     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") | ||||||
|     message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") |     message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") | ||||||
|   else() |   else() | ||||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS) |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) | ||||||
|       message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " |       message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " | ||||||
|                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or " |                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or " | ||||||
|                      "later if you intend on running FP8 quantized models on " |                      "later if you intend on running FP8 quantized models on " | ||||||
| @ -420,10 +418,40 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|     endif() |     endif() | ||||||
|   endif() |   endif() | ||||||
|  |  | ||||||
|   # The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require |  | ||||||
|  |   # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require | ||||||
|   # CUDA 12.8 or later |   # CUDA 12.8 or later | ||||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) | ||||||
|  |     set(SRCS | ||||||
|  |       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" | ||||||
|  |       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" | ||||||
|  |     ) | ||||||
|  |     set_gencode_flags_for_srcs( | ||||||
|  |       SRCS "${SRCS}" | ||||||
|  |       CUDA_ARCHS "${SCALED_MM_ARCHS}") | ||||||
|  |     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||||
|  |     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") | ||||||
|  |     # Let scaled_mm_c2x know it doesn't need to build these arches | ||||||
|  |     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") | ||||||
|  |     message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") | ||||||
|  |   else() | ||||||
|  |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) | ||||||
|  |       message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " | ||||||
|  |                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or " | ||||||
|  |                      "later if you intend on running FP8 quantized models on " | ||||||
|  |                      "Blackwell.") | ||||||
|  |     else() | ||||||
|  |       message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " | ||||||
|  |                      "in CUDA target architectures") | ||||||
|  |     endif() | ||||||
|  |   endif() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |   # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) | ||||||
|  |   # require CUDA 12.8 or later | ||||||
|  |   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}") | ||||||
|  |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) | ||||||
|     set(SRCS |     set(SRCS | ||||||
|       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" |       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" | ||||||
|       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" |       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" | ||||||
| @ -438,7 +466,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") |     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") | ||||||
|     message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") |     message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") | ||||||
|   else() |   else() | ||||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) | ||||||
|       message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " |       message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " | ||||||
|                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or " |                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or " | ||||||
|                      "later if you intend on running FP8 quantized models on " |                      "later if you intend on running FP8 quantized models on " | ||||||
| @ -481,7 +509,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor |   # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor | ||||||
|   # require CUDA 12.2 or later (and only work on Hopper). |   # require CUDA 12.2 or later (and only work on Hopper). | ||||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) | ||||||
|     set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") |     set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") | ||||||
|     set_gencode_flags_for_srcs( |     set_gencode_flags_for_srcs( | ||||||
|       SRCS "${SRCS}" |       SRCS "${SRCS}" | ||||||
| @ -490,7 +518,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") |     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") | ||||||
|     message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") |     message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") | ||||||
|   else() |   else() | ||||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS) |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) | ||||||
|       message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " |       message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " | ||||||
|                      "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " |                      "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " | ||||||
|                      "if you intend on running FP8 sparse quantized models on Hopper.") |                      "if you intend on running FP8 sparse quantized models on Hopper.") | ||||||
| @ -502,7 +530,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|  |  | ||||||
|   # FP4 Archs and flags |   # FP4 Archs and flags | ||||||
|   cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) | ||||||
|     set(SRCS |     set(SRCS | ||||||
|       "csrc/quantization/fp4/nvfp4_quant_kernels.cu" |       "csrc/quantization/fp4/nvfp4_quant_kernels.cu" | ||||||
|       "csrc/quantization/fp4/nvfp4_experts_quant.cu" |       "csrc/quantization/fp4/nvfp4_experts_quant.cu" | ||||||
| @ -513,6 +541,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|       CUDA_ARCHS "${FP4_ARCHS}") |       CUDA_ARCHS "${FP4_ARCHS}") | ||||||
|     list(APPEND VLLM_EXT_SRC "${SRCS}") |     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||||
|     list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") |     list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") | ||||||
|  |     list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") | ||||||
|     message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") |     message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") | ||||||
|   else() |   else() | ||||||
|     message(STATUS "Not building NVFP4 as no compatible archs were found.") |     message(STATUS "Not building NVFP4 as no compatible archs were found.") | ||||||
| @ -522,9 +551,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|  |  | ||||||
|   # CUTLASS MLA Archs and flags |   # CUTLASS MLA Archs and flags | ||||||
|   cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) | ||||||
|     set(SRCS |     set(SRCS | ||||||
|       "csrc/attention/mla/cutlass_mla_kernels.cu") |       "csrc/attention/mla/cutlass_mla_kernels.cu" | ||||||
|  |       "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") | ||||||
|     set_gencode_flags_for_srcs( |     set_gencode_flags_for_srcs( | ||||||
|       SRCS "${SRCS}" |       SRCS "${SRCS}" | ||||||
|       CUDA_ARCHS "${MLA_ARCHS}") |       CUDA_ARCHS "${MLA_ARCHS}") | ||||||
| @ -542,13 +572,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|  |  | ||||||
|   # CUTLASS MoE kernels |   # CUTLASS MoE kernels | ||||||
|  |  | ||||||
|   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works |   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works | ||||||
|   # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled |   # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled | ||||||
|   # if it's possible to compile MoE kernels that use its output. |   # if it's possible to compile MoE kernels that use its output. | ||||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) | ||||||
|     set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" |     set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") | ||||||
|              "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") |  | ||||||
|     set_gencode_flags_for_srcs( |     set_gencode_flags_for_srcs( | ||||||
|       SRCS "${SRCS}" |       SRCS "${SRCS}" | ||||||
|       CUDA_ARCHS "${SCALED_MM_ARCHS}") |       CUDA_ARCHS "${SCALED_MM_ARCHS}") | ||||||
| @ -562,6 +591,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|                      "if you intend on running FP8 quantized MoE models on Hopper.") |                      "if you intend on running FP8 quantized MoE models on Hopper.") | ||||||
|     else() |     else() | ||||||
|       message(STATUS "Not building grouped_mm_c3x as no compatible archs found " |       message(STATUS "Not building grouped_mm_c3x as no compatible archs found " | ||||||
|  |                      "in CUDA target architectures.") | ||||||
|  |     endif() | ||||||
|  |   endif() | ||||||
|  |  | ||||||
|  |   # moe_data.cu is used by all CUTLASS MoE kernels. | ||||||
|  |   cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") | ||||||
|  |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) | ||||||
|  |     set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") | ||||||
|  |     set_gencode_flags_for_srcs( | ||||||
|  |       SRCS "${SRCS}" | ||||||
|  |       CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") | ||||||
|  |     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||||
|  |     message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") | ||||||
|  |   else() | ||||||
|  |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) | ||||||
|  |       message(STATUS "Not building moe_data as CUDA Compiler version is " | ||||||
|  |                      "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " | ||||||
|  |                      "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") | ||||||
|  |     else() | ||||||
|  |       message(STATUS "Not building moe_data as no compatible archs found " | ||||||
|  |                      "in CUDA target architectures.") | ||||||
|  |     endif() | ||||||
|  |   endif() | ||||||
|  |    | ||||||
|  |   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") | ||||||
|  |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) | ||||||
|  |     set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu") | ||||||
|  |     set_gencode_flags_for_srcs( | ||||||
|  |       SRCS "${SRCS}" | ||||||
|  |       CUDA_ARCHS "${SCALED_MM_ARCHS}") | ||||||
|  |     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||||
|  |     list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") | ||||||
|  |     message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") | ||||||
|  |   else() | ||||||
|  |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) | ||||||
|  |       message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " | ||||||
|  |                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " | ||||||
|  |                      "if you intend on running FP8 quantized MoE models on Blackwell.") | ||||||
|  |     else() | ||||||
|  |       message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " | ||||||
|                      "in CUDA target architectures") |                      "in CUDA target architectures") | ||||||
|     endif() |     endif() | ||||||
|   endif() |   endif() | ||||||
| @ -572,7 +641,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|   # The machete kernels only work on hopper and require CUDA 12.0 or later. |   # The machete kernels only work on hopper and require CUDA 12.0 or later. | ||||||
|   # Only build Machete kernels if we are building for something compatible with sm90a |   # Only build Machete kernels if we are building for something compatible with sm90a | ||||||
|   cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") |   cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") | ||||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS) |   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS) | ||||||
|     # |     # | ||||||
|     # For the Machete kernels we automatically generate sources for various |     # For the Machete kernels we automatically generate sources for various | ||||||
|     # preselected input type pairs and schedules. |     # preselected input type pairs and schedules. | ||||||
| @ -624,7 +693,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
|  |  | ||||||
|     message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") |     message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") | ||||||
|   else() |   else() | ||||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 |     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 | ||||||
|         AND MACHETE_ARCHS) |         AND MACHETE_ARCHS) | ||||||
|       message(STATUS "Not building Machete kernels as CUDA Compiler version is " |       message(STATUS "Not building Machete kernels as CUDA Compiler version is " | ||||||
|                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or " |                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or " | ||||||
| @ -638,6 +707,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | |||||||
| # if CUDA endif | # if CUDA endif | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
|  | if (VLLM_GPU_LANG STREQUAL "HIP") | ||||||
|  |   # Add QuickReduce kernels | ||||||
|  |   list(APPEND VLLM_EXT_SRC | ||||||
|  |     "csrc/custom_quickreduce.cu" | ||||||
|  |   ) | ||||||
|  | # if ROCM endif | ||||||
|  | endif() | ||||||
|  |  | ||||||
| message(STATUS "Enabling C extension.") | message(STATUS "Enabling C extension.") | ||||||
| define_gpu_extension_target( | define_gpu_extension_target( | ||||||
|   _C |   _C | ||||||
|  | |||||||
| @ -63,13 +63,11 @@ vLLM is fast with: | |||||||
| - Speculative decoding | - Speculative decoding | ||||||
| - Chunked prefill | - Chunked prefill | ||||||
|  |  | ||||||
| **Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script. |  | ||||||
|  |  | ||||||
| vLLM is flexible and easy to use with: | vLLM is flexible and easy to use with: | ||||||
|  |  | ||||||
| - Seamless integration with popular Hugging Face models | - Seamless integration with popular Hugging Face models | ||||||
| - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more | - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more | ||||||
| - Tensor parallelism and pipeline parallelism support for distributed inference | - Tensor, pipeline, data and expert parallelism support for distributed inference | ||||||
| - Streaming outputs | - Streaming outputs | ||||||
| - OpenAI-compatible API server | - OpenAI-compatible API server | ||||||
| - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron | - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron | ||||||
| @ -154,11 +152,13 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs | |||||||
|  |  | ||||||
| ## Contact Us | ## Contact Us | ||||||
|  |  | ||||||
|  | <!-- --8<-- [start:contact-us] --> | ||||||
| - For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions) | - For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions) | ||||||
| - For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai) | - For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai) | ||||||
| - coordinating contributions and development, please use [Slack](https://slack.vllm.ai) | - For coordinating contributions and development, please use [Slack](https://slack.vllm.ai) | ||||||
| - For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature | - For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature | ||||||
| - For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu) | - For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu) | ||||||
|  | <!-- --8<-- [end:contact-us] --> | ||||||
|  |  | ||||||
| ## Media Kit | ## Media Kit | ||||||
|  |  | ||||||
|  | |||||||
| @ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive | |||||||
| datasets supported on vLLM. It’s a living document, updated as new features and datasets | datasets supported on vLLM. It’s a living document, updated as new features and datasets | ||||||
| become available. | become available. | ||||||
|  |  | ||||||
| ## Dataset Overview | **Dataset Overview** | ||||||
|  |  | ||||||
| <table style="width:100%; border-collapse: collapse;"> | <table style="width:100%; border-collapse: collapse;"> | ||||||
|   <thead> |   <thead> | ||||||
| @ -82,7 +82,10 @@ become available. | |||||||
| **Note**: HuggingFace dataset's `dataset-name` should be set to `hf` | **Note**: HuggingFace dataset's `dataset-name` should be set to `hf` | ||||||
|  |  | ||||||
| --- | --- | ||||||
| ## Example - Online Benchmark | <details> | ||||||
|  | <summary><b>🚀 Example - Online Benchmark</b></summary> | ||||||
|  |  | ||||||
|  | <br/> | ||||||
|  |  | ||||||
| First start serving your model | First start serving your model | ||||||
|  |  | ||||||
| @ -130,7 +133,8 @@ P99 ITL (ms):                            8.39 | |||||||
| ================================================== | ================================================== | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### Custom Dataset | **Custom Dataset** | ||||||
|  |  | ||||||
| If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have a "prompt" field per entry, e.g., data.jsonl | If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have a "prompt" field per entry, e.g., data.jsonl | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
| @ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile | |||||||
|  |  | ||||||
| You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. | You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. | ||||||
|  |  | ||||||
| ### VisionArena Benchmark for Vision Language Models | **VisionArena Benchmark for Vision Language Models** | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| # need a model with vision capability here | # need a model with vision capability here | ||||||
| @ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ | |||||||
|   --num-prompts 1000 |   --num-prompts 1000 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### InstructCoder Benchmark with Speculative Decoding | **InstructCoder Benchmark with Speculative Decoding** | ||||||
|  |  | ||||||
| ``` bash | ``` bash | ||||||
| VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ | VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ | ||||||
| @ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \ | |||||||
|     --num-prompts 2048 |     --num-prompts 2048 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### Other HuggingFaceDataset Examples | **Other HuggingFaceDataset Examples** | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests | vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests | ||||||
| @ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ | |||||||
|     --num-prompts 80 |     --num-prompts 80 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### Running With Sampling Parameters | **Running With Sampling Parameters** | ||||||
|  |  | ||||||
| When using OpenAI-compatible backends such as `vllm`, optional sampling | When using OpenAI-compatible backends such as `vllm`, optional sampling | ||||||
| parameters can be specified. Example client command: | parameters can be specified. Example client command: | ||||||
| @ -269,8 +273,27 @@ python3 vllm/benchmarks/benchmark_serving.py \ | |||||||
|   --num-prompts 10 |   --num-prompts 10 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| --- | **Running With Ramp-Up Request Rate** | ||||||
| ## Example - Offline Throughput Benchmark |  | ||||||
|  | The benchmark tool also supports ramping up the request rate over the | ||||||
|  | duration of the benchmark run. This can be useful for stress testing the | ||||||
|  | server or finding the maximum throughput that it can handle, given some latency budget. | ||||||
|  |  | ||||||
|  | Two ramp-up strategies are supported: | ||||||
|  | - `linear`: Increases the request rate linearly from a start value to an end value. | ||||||
|  | - `exponential`: Increases the request rate exponentially. | ||||||
|  |  | ||||||
|  | The following arguments can be used to control the ramp-up: | ||||||
|  | - `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). | ||||||
|  | - `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. | ||||||
|  | - `--ramp-up-end-rps`: The request rate at the end of the benchmark. | ||||||
|  |  | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><b>📈 Example - Offline Throughput Benchmark</b></summary> | ||||||
|  |  | ||||||
|  | <br/> | ||||||
|  |  | ||||||
| ```bash | ```bash | ||||||
| python3 vllm/benchmarks/benchmark_throughput.py \ | python3 vllm/benchmarks/benchmark_throughput.py \ | ||||||
| @ -288,7 +311,7 @@ Total num prompt tokens:  5014 | |||||||
| Total num output tokens:  1500 | Total num output tokens:  1500 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### VisionArena Benchmark for Vision Language Models | **VisionArena Benchmark for Vision Language Models** | ||||||
|  |  | ||||||
| ``` bash | ``` bash | ||||||
| python3 vllm/benchmarks/benchmark_throughput.py \ | python3 vllm/benchmarks/benchmark_throughput.py \ | ||||||
| @ -308,7 +331,7 @@ Total num prompt tokens:  14527 | |||||||
| Total num output tokens:  1280 | Total num output tokens:  1280 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### InstructCoder Benchmark with Speculative Decoding | **InstructCoder Benchmark with Speculative Decoding** | ||||||
|  |  | ||||||
| ``` bash | ``` bash | ||||||
| VLLM_WORKER_MULTIPROC_METHOD=spawn \ | VLLM_WORKER_MULTIPROC_METHOD=spawn \ | ||||||
| @ -332,7 +355,7 @@ Total num prompt tokens:  261136 | |||||||
| Total num output tokens:  204800 | Total num output tokens:  204800 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### Other HuggingFaceDataset Examples | **Other HuggingFaceDataset Examples** | ||||||
|  |  | ||||||
| **`lmms-lab/LLaVA-OneVision-Data`** | **`lmms-lab/LLaVA-OneVision-Data`** | ||||||
|  |  | ||||||
| @ -371,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \ | |||||||
|   --num-prompts 10 |   --num-prompts 10 | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ### Benchmark with LoRA Adapters | **Benchmark with LoRA Adapters** | ||||||
|  |  | ||||||
| ``` bash | ``` bash | ||||||
| # download dataset | # download dataset | ||||||
| @ -387,3 +410,196 @@ python3 vllm/benchmarks/benchmark_throughput.py \ | |||||||
|   --enable-lora \ |   --enable-lora \ | ||||||
|   --lora-path yard1/llama-2-7b-sql-lora-test |   --lora-path yard1/llama-2-7b-sql-lora-test | ||||||
|   ``` |   ``` | ||||||
|  |  | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><b>🛠️ Example - Structured Output Benchmark</b></summary> | ||||||
|  |  | ||||||
|  | <br/> | ||||||
|  |  | ||||||
|  | Benchmark the performance of structured output generation (JSON, grammar, regex). | ||||||
|  |  | ||||||
|  | **Server Setup** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **JSON Schema Benchmark** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_serving_structured_output.py \ | ||||||
|  |   --backend vllm \ | ||||||
|  |   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||||
|  |   --dataset json \ | ||||||
|  |   --structured-output-ratio 1.0 \ | ||||||
|  |   --request-rate 10 \ | ||||||
|  |   --num-prompts 1000 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Grammar-based Generation Benchmark** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_serving_structured_output.py \ | ||||||
|  |   --backend vllm \ | ||||||
|  |   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||||
|  |   --dataset grammar \ | ||||||
|  |   --structure-type grammar \ | ||||||
|  |   --request-rate 10 \ | ||||||
|  |   --num-prompts 1000 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Regex-based Generation Benchmark** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_serving_structured_output.py \ | ||||||
|  |   --backend vllm \ | ||||||
|  |   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||||
|  |   --dataset regex \ | ||||||
|  |   --request-rate 10 \ | ||||||
|  |   --num-prompts 1000 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Choice-based Generation Benchmark** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_serving_structured_output.py \ | ||||||
|  |   --backend vllm \ | ||||||
|  |   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||||
|  |   --dataset choice \ | ||||||
|  |   --request-rate 10 \ | ||||||
|  |   --num-prompts 1000 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **XGrammar Benchmark Dataset** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_serving_structured_output.py \ | ||||||
|  |   --backend vllm \ | ||||||
|  |   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||||
|  |   --dataset xgrammar_bench \ | ||||||
|  |   --request-rate 10 \ | ||||||
|  |   --num-prompts 1000 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><b>📚 Example - Long Document QA Benchmark</b></summary> | ||||||
|  |  | ||||||
|  | <br/> | ||||||
|  |  | ||||||
|  | Benchmark the performance of long document question-answering with prefix caching. | ||||||
|  |  | ||||||
|  | **Basic Long Document QA Test** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --enable-prefix-caching \ | ||||||
|  |   --num-documents 16 \ | ||||||
|  |   --document-length 2000 \ | ||||||
|  |   --output-len 50 \ | ||||||
|  |   --repeat-count 5 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Different Repeat Modes** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | # Random mode (default) - shuffle prompts randomly | ||||||
|  | python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --enable-prefix-caching \ | ||||||
|  |   --num-documents 8 \ | ||||||
|  |   --document-length 3000 \ | ||||||
|  |   --repeat-count 3 \ | ||||||
|  |   --repeat-mode random | ||||||
|  |  | ||||||
|  | # Tile mode - repeat entire prompt list in sequence | ||||||
|  | python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --enable-prefix-caching \ | ||||||
|  |   --num-documents 8 \ | ||||||
|  |   --document-length 3000 \ | ||||||
|  |   --repeat-count 3 \ | ||||||
|  |   --repeat-mode tile | ||||||
|  |  | ||||||
|  | # Interleave mode - repeat each prompt consecutively | ||||||
|  | python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --enable-prefix-caching \ | ||||||
|  |   --num-documents 8 \ | ||||||
|  |   --document-length 3000 \ | ||||||
|  |   --repeat-count 3 \ | ||||||
|  |   --repeat-mode interleave | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><b>🗂️ Example - Prefix Caching Benchmark</b></summary> | ||||||
|  |  | ||||||
|  | <br/> | ||||||
|  |  | ||||||
|  | Benchmark the efficiency of automatic prefix caching. | ||||||
|  |  | ||||||
|  | **Fixed Prompt with Prefix Caching** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_prefix_caching.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --enable-prefix-caching \ | ||||||
|  |   --num-prompts 1 \ | ||||||
|  |   --repeat-count 100 \ | ||||||
|  |   --input-length-range 128:256 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **ShareGPT Dataset with Prefix Caching** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | # download dataset | ||||||
|  | # wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json | ||||||
|  |  | ||||||
|  | python3 benchmarks/benchmark_prefix_caching.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ | ||||||
|  |   --enable-prefix-caching \ | ||||||
|  |   --num-prompts 20 \ | ||||||
|  |   --repeat-count 5 \ | ||||||
|  |   --input-length-range 128:256 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | </details> | ||||||
|  |  | ||||||
|  | <details> | ||||||
|  | <summary><b>⚡ Example - Request Prioritization Benchmark</b></summary> | ||||||
|  |  | ||||||
|  | <br/> | ||||||
|  |  | ||||||
|  | Benchmark the performance of request prioritization in vLLM. | ||||||
|  |  | ||||||
|  | **Basic Prioritization Test** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_prioritization.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --input-len 128 \ | ||||||
|  |   --output-len 64 \ | ||||||
|  |   --num-prompts 100 \ | ||||||
|  |   --scheduling-policy priority | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | **Multiple Sequences per Prompt** | ||||||
|  |  | ||||||
|  | ```bash | ||||||
|  | python3 benchmarks/benchmark_prioritization.py \ | ||||||
|  |   --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |   --input-len 128 \ | ||||||
|  |   --output-len 64 \ | ||||||
|  |   --num-prompts 100 \ | ||||||
|  |   --scheduling-policy priority \ | ||||||
|  |   --n 2 | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | </details> | ||||||
|  | |||||||
| @ -10,6 +10,7 @@ | |||||||
| # 3. Set variables (ALL REQUIRED) | # 3. Set variables (ALL REQUIRED) | ||||||
| #   BASE: your directory for vllm repo | #   BASE: your directory for vllm repo | ||||||
| #   MODEL: the model served by vllm | #   MODEL: the model served by vllm | ||||||
|  | #   SYSTEM: the hardware type, either TPU or GPU; on other systems, "get best profile" might not be supported. | ||||||
| #   TP: ways of tensor parallelism | #   TP: ways of tensor parallelism | ||||||
| #   DOWNLOAD_DIR: directory to download and load model weights. | #   DOWNLOAD_DIR: directory to download and load model weights. | ||||||
| #   INPUT_LEN: request input len | #   INPUT_LEN: request input len | ||||||
| @ -34,6 +35,7 @@ | |||||||
| TAG=$(date +"%Y_%m_%d_%H_%M") | TAG=$(date +"%Y_%m_%d_%H_%M") | ||||||
| BASE="" | BASE="" | ||||||
| MODEL="meta-llama/Llama-3.1-8B-Instruct" | MODEL="meta-llama/Llama-3.1-8B-Instruct" | ||||||
|  | SYSTEM="TPU" | ||||||
| TP=1 | TP=1 | ||||||
| DOWNLOAD_DIR="" | DOWNLOAD_DIR="" | ||||||
| INPUT_LEN=4000 | INPUT_LEN=4000 | ||||||
| @ -45,12 +47,15 @@ NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" | |||||||
|  |  | ||||||
| LOG_FOLDER="$BASE/auto-benchmark/$TAG" | LOG_FOLDER="$BASE/auto-benchmark/$TAG" | ||||||
| RESULT="$LOG_FOLDER/result.txt" | RESULT="$LOG_FOLDER/result.txt" | ||||||
|  | PROFILE_PATH="$LOG_FOLDER/profile" | ||||||
|  |  | ||||||
| echo "result file: $RESULT" | echo "result file: $RESULT" | ||||||
| echo "model: $MODEL" | echo "model: $MODEL" | ||||||
|  |  | ||||||
| rm -rf $LOG_FOLDER | rm -rf $LOG_FOLDER | ||||||
|  | rm -rf $PROFILE_PATH | ||||||
| mkdir -p $LOG_FOLDER | mkdir -p $LOG_FOLDER | ||||||
|  | mkdir -p $PROFILE_PATH | ||||||
|  |  | ||||||
| cd "$BASE/vllm" | cd "$BASE/vllm" | ||||||
|  |  | ||||||
| @ -70,10 +75,11 @@ start_server() { | |||||||
|     local max_num_seqs=$2 |     local max_num_seqs=$2 | ||||||
|     local max_num_batched_tokens=$3 |     local max_num_batched_tokens=$3 | ||||||
|     local vllm_log=$4 |     local vllm_log=$4 | ||||||
|  |     local profile_dir=$5 | ||||||
|      |      | ||||||
|     pkill -f vllm |     pkill -f vllm | ||||||
|  |  | ||||||
|     VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \ |     VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \ | ||||||
|         --disable-log-requests \ |         --disable-log-requests \ | ||||||
|         --port 8004 \ |         --port 8004 \ | ||||||
|         --gpu-memory-utilization $gpu_memory_utilization \ |         --gpu-memory-utilization $gpu_memory_utilization \ | ||||||
| @ -105,19 +111,37 @@ start_server() { | |||||||
|     fi |     fi | ||||||
| } | } | ||||||
|  |  | ||||||
|  | update_best_profile() { | ||||||
|  |     local profile_dir=$1 | ||||||
|  |     local profile_index=$2 | ||||||
|  |     sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort)) | ||||||
|  |     selected_profile_file= | ||||||
|  |     if [[ "$SYSTEM" == "TPU" ]]; then | ||||||
|  |         selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb" | ||||||
|  |     fi  | ||||||
|  |     if [[ "$SYSTEM" == "GPU" ]]; then | ||||||
|  |         selected_profile_file="${sorted_paths[$profile_index]}" | ||||||
|  |     fi  | ||||||
|  |     rm -f $PROFILE_PATH/* | ||||||
|  |     cp $selected_profile_file $PROFILE_PATH | ||||||
|  | } | ||||||
|  |  | ||||||
| run_benchmark() { | run_benchmark() { | ||||||
|     local max_num_seqs=$1 |     local max_num_seqs=$1 | ||||||
|     local max_num_batched_tokens=$2 |     local max_num_batched_tokens=$2 | ||||||
|     local gpu_memory_utilization=$3 |     local gpu_memory_utilization=$3 | ||||||
|     echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" |     echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" | ||||||
|     local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" |     local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" | ||||||
|  |     local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}" | ||||||
|     echo "vllm_log: $vllm_log" |     echo "vllm_log: $vllm_log" | ||||||
|     echo |     echo | ||||||
|     rm -f $vllm_log |     rm -f $vllm_log | ||||||
|  |     mkdir -p $profile_dir | ||||||
|     pkill -f vllm |     pkill -f vllm | ||||||
|  |     local profile_index=0 | ||||||
|  |  | ||||||
|     echo "starting server..." |     echo "starting server..." | ||||||
|     start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log |     start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir | ||||||
|     result=$? |     result=$? | ||||||
|     if [[ "$result" -eq 1 ]]; then |     if [[ "$result" -eq 1 ]]; then | ||||||
|         echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" |         echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" | ||||||
| @ -144,7 +168,8 @@ run_benchmark() { | |||||||
|         --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ |         --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ | ||||||
|         --num-prompts 1000 \ |         --num-prompts 1000 \ | ||||||
|         --random-prefix-len $prefix_len \ |         --random-prefix-len $prefix_len \ | ||||||
|         --port 8004 &> "$bm_log" |         --port 8004 \ | ||||||
|  |         --profile &> "$bm_log" | ||||||
|     throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') |     throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') | ||||||
|     e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') |     e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') | ||||||
|     goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') |     goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') | ||||||
| @ -158,6 +183,7 @@ run_benchmark() { | |||||||
|     # start from request-rate as int(throughput) + 1 |     # start from request-rate as int(throughput) + 1 | ||||||
|         request_rate=$((${throughput%.*} + 1)) |         request_rate=$((${throughput%.*} + 1)) | ||||||
|         while ((request_rate > 0)); do |         while ((request_rate > 0)); do | ||||||
|  |             profile_index=$((profile_index+1)) | ||||||
|             # clear prefix cache |             # clear prefix cache | ||||||
|             curl -X POST http://0.0.0.0:8004/reset_prefix_cache |             curl -X POST http://0.0.0.0:8004/reset_prefix_cache | ||||||
|             sleep 5 |             sleep 5 | ||||||
| @ -195,6 +221,12 @@ run_benchmark() { | |||||||
|             best_max_num_seqs=$max_num_seqs |             best_max_num_seqs=$max_num_seqs | ||||||
|             best_num_batched_tokens=$max_num_batched_tokens |             best_num_batched_tokens=$max_num_batched_tokens | ||||||
|             best_goodput=$goodput |             best_goodput=$goodput | ||||||
|  |             if [[ "$SYSTEM" == "TPU" ]]; then | ||||||
|  |                 update_best_profile "$profile_dir/plugins/profile" $profile_index | ||||||
|  |             fi | ||||||
|  |             if [[ "$SYSTEM" == "GPU" ]]; then | ||||||
|  |                 update_best_profile "$profile_dir" $profile_index | ||||||
|  |             fi | ||||||
|         fi |         fi | ||||||
|     else |     else | ||||||
|         echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" |         echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" | ||||||
| @ -239,6 +271,6 @@ for num_seqs in "${num_seqs_list[@]}"; do | |||||||
|     done |     done | ||||||
| done | done | ||||||
| echo "finish permutations" | echo "finish permutations" | ||||||
| echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" | echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" | ||||||
| echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT" | echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" | ||||||
|  |  | ||||||
|  | |||||||
| @ -404,8 +404,14 @@ async def async_request_openai_chat_completions( | |||||||
|                         chunk_bytes = chunk_bytes.strip() |                         chunk_bytes = chunk_bytes.strip() | ||||||
|                         if not chunk_bytes: |                         if not chunk_bytes: | ||||||
|                             continue |                             continue | ||||||
|  |                         chunk_bytes = chunk_bytes.decode("utf-8") | ||||||
|  |                         # NOTE: SSE comments (often used as pings) start with a colon. | ||||||
|  |                         # These are not JSON data payloads and should be skipped. | ||||||
|  |                         if chunk_bytes.startswith(":"): | ||||||
|  |                             continue | ||||||
|  |  | ||||||
|  |                         chunk = chunk_bytes.removeprefix("data: ") | ||||||
|  |  | ||||||
|                         chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") |  | ||||||
|                         if chunk != "[DONE]": |                         if chunk != "[DONE]": | ||||||
|                             timestamp = time.perf_counter() |                             timestamp = time.perf_counter() | ||||||
|                             data = json.loads(chunk) |                             data = json.loads(chunk) | ||||||
|  | |||||||
| @ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset): | |||||||
|         input_low = int(real_input_len * (1 - range_ratio)) |         input_low = int(real_input_len * (1 - range_ratio)) | ||||||
|         input_high = int(real_input_len * (1 + range_ratio)) |         input_high = int(real_input_len * (1 + range_ratio)) | ||||||
|         output_low = int(output_len * (1 - range_ratio)) |         output_low = int(output_len * (1 - range_ratio)) | ||||||
|  |         # Ensure the lower bound for output length is at least 1 to prevent | ||||||
|  |         # sampling 0 tokens, which can cause request failures. | ||||||
|  |         output_low = max(output_low, 1) | ||||||
|         output_high = int(output_len * (1 + range_ratio)) |         output_high = int(output_len * (1 + range_ratio)) | ||||||
|  |  | ||||||
|         # Add logging for debugging |         # Add logging for debugging | ||||||
| @ -349,11 +352,12 @@ class RandomDataset(BenchmarkDataset): | |||||||
|             # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] |             # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] | ||||||
|             # To avoid uncontrolled change of the prompt length, |             # To avoid uncontrolled change of the prompt length, | ||||||
|             # the encoded sequence is truncated before being decoded again. |             # the encoded sequence is truncated before being decoded again. | ||||||
|  |             total_input_len = prefix_len + int(input_lens[i]) | ||||||
|             re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ |             re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ | ||||||
|                 : input_lens[i] |                 :total_input_len | ||||||
|             ] |             ] | ||||||
|             prompt = tokenizer.decode(re_encoded_sequence) |             prompt = tokenizer.decode(re_encoded_sequence) | ||||||
|             total_input_len = prefix_len + int(input_lens[i]) |             total_input_len = len(re_encoded_sequence) | ||||||
|             requests.append( |             requests.append( | ||||||
|                 SampleRequest( |                 SampleRequest( | ||||||
|                     prompt=prompt, |                     prompt=prompt, | ||||||
| @ -700,6 +704,7 @@ class HuggingFaceDataset(BenchmarkDataset): | |||||||
|         self, |         self, | ||||||
|         dataset_path: str, |         dataset_path: str, | ||||||
|         dataset_split: str, |         dataset_split: str, | ||||||
|  |         no_stream: bool = False, | ||||||
|         dataset_subset: Optional[str] = None, |         dataset_subset: Optional[str] = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ) -> None: |     ) -> None: | ||||||
| @ -707,6 +712,7 @@ class HuggingFaceDataset(BenchmarkDataset): | |||||||
|  |  | ||||||
|         self.dataset_split = dataset_split |         self.dataset_split = dataset_split | ||||||
|         self.dataset_subset = dataset_subset |         self.dataset_subset = dataset_subset | ||||||
|  |         self.load_stream = not no_stream | ||||||
|         self.load_data() |         self.load_data() | ||||||
|  |  | ||||||
|     def load_data(self) -> None: |     def load_data(self) -> None: | ||||||
| @ -715,7 +721,7 @@ class HuggingFaceDataset(BenchmarkDataset): | |||||||
|             self.dataset_path, |             self.dataset_path, | ||||||
|             name=self.dataset_subset, |             name=self.dataset_subset, | ||||||
|             split=self.dataset_split, |             split=self.dataset_split, | ||||||
|             streaming=True, |             streaming=self.load_stream, | ||||||
|         ) |         ) | ||||||
|         self.data = self.data.shuffle(seed=self.random_seed) |         self.data = self.data.shuffle(seed=self.random_seed) | ||||||
|  |  | ||||||
|  | |||||||
| @ -123,7 +123,7 @@ def main(args: argparse.Namespace): | |||||||
|         save_to_pytorch_benchmark_format(args, results) |         save_to_pytorch_benchmark_format(args, results) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser( |     parser = FlexibleArgumentParser( | ||||||
|         description="Benchmark the latency of processing a single batch of " |         description="Benchmark the latency of processing a single batch of " | ||||||
|         "requests till completion." |         "requests till completion." | ||||||
| @ -171,6 +171,12 @@ if __name__ == "__main__": | |||||||
|     # V1 enables prefix caching by default which skews the latency |     # V1 enables prefix caching by default which skews the latency | ||||||
|     # numbers. We need to disable prefix caching by default. |     # numbers. We need to disable prefix caching by default. | ||||||
|     parser.set_defaults(enable_prefix_caching=False) |     parser.set_defaults(enable_prefix_caching=False) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: |     if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: | ||||||
|         raise OSError( |         raise OSError( | ||||||
|  | |||||||
| @ -142,7 +142,7 @@ def main(args): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser( |     parser = FlexibleArgumentParser( | ||||||
|         description="Benchmark the performance with or " |         description="Benchmark the performance with or " | ||||||
|         "without automatic prefix caching." |         "without automatic prefix caching." | ||||||
| @ -192,5 +192,11 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     parser = EngineArgs.add_cli_args(parser) |     parser = EngineArgs.add_cli_args(parser) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     main(args) |     main(args) | ||||||
|  | |||||||
| @ -218,7 +218,7 @@ def main(args): | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser( |     parser = FlexibleArgumentParser( | ||||||
|         description="Benchmark the performance with or without " |         description="Benchmark the performance with or without " | ||||||
|         "automatic prefix caching." |         "automatic prefix caching." | ||||||
| @ -268,5 +268,11 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     parser = EngineArgs.add_cli_args(parser) |     parser = EngineArgs.add_cli_args(parser) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     main(args) |     main(args) | ||||||
|  | |||||||
| @ -161,7 +161,7 @@ def main(args: argparse.Namespace): | |||||||
|             json.dump(results, f, indent=4) |             json.dump(results, f, indent=4) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser(description="Benchmark the throughput.") |     parser = FlexibleArgumentParser(description="Benchmark the throughput.") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" |         "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" | ||||||
| @ -204,6 +204,12 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     parser = EngineArgs.add_cli_args(parser) |     parser = EngineArgs.add_cli_args(parser) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     if args.tokenizer is None: |     if args.tokenizer is None: | ||||||
|         args.tokenizer = args.model |         args.tokenizer = args.model | ||||||
|  | |||||||
| @ -33,7 +33,7 @@ import warnings | |||||||
| from collections.abc import AsyncGenerator, Iterable | from collections.abc import AsyncGenerator, Iterable | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| from datetime import datetime | from datetime import datetime | ||||||
| from typing import Any, Optional | from typing import Any, Literal, Optional | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| from tqdm.asyncio import tqdm | from tqdm.asyncio import tqdm | ||||||
| @ -107,14 +107,42 @@ class BenchmarkMetrics: | |||||||
|     percentiles_e2el_ms: list[tuple[float, float]] |     percentiles_e2el_ms: list[tuple[float, float]] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _get_current_request_rate( | ||||||
|  |     ramp_up_strategy: Optional[Literal["linear", "exponential"]], | ||||||
|  |     ramp_up_start_rps: Optional[int], | ||||||
|  |     ramp_up_end_rps: Optional[int], | ||||||
|  |     request_index: int, | ||||||
|  |     total_requests: int, | ||||||
|  |     request_rate: float, | ||||||
|  | ) -> float: | ||||||
|  |     if ( | ||||||
|  |         ramp_up_strategy | ||||||
|  |         and ramp_up_start_rps is not None | ||||||
|  |         and ramp_up_end_rps is not None | ||||||
|  |     ): | ||||||
|  |         progress = request_index / max(total_requests - 1, 1) | ||||||
|  |         if ramp_up_strategy == "linear": | ||||||
|  |             increase = (ramp_up_end_rps - ramp_up_start_rps) * progress | ||||||
|  |             return ramp_up_start_rps + increase | ||||||
|  |         elif ramp_up_strategy == "exponential": | ||||||
|  |             ratio = ramp_up_end_rps / ramp_up_start_rps | ||||||
|  |             return ramp_up_start_rps * (ratio**progress) | ||||||
|  |         else: | ||||||
|  |             raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") | ||||||
|  |     return request_rate | ||||||
|  |  | ||||||
|  |  | ||||||
| async def get_request( | async def get_request( | ||||||
|     input_requests: list[SampleRequest], |     input_requests: list[SampleRequest], | ||||||
|     request_rate: float, |     request_rate: float, | ||||||
|     burstiness: float = 1.0, |     burstiness: float = 1.0, | ||||||
| ) -> AsyncGenerator[SampleRequest, None]: |     ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, | ||||||
|  |     ramp_up_start_rps: Optional[int] = None, | ||||||
|  |     ramp_up_end_rps: Optional[int] = None, | ||||||
|  | ) -> AsyncGenerator[tuple[SampleRequest, float], None]: | ||||||
|     """ |     """ | ||||||
|     Asynchronously generates requests at a specified rate |     Asynchronously generates requests at a specified rate | ||||||
|     with OPTIONAL burstiness. |     with OPTIONAL burstiness and OPTIONAL ramp-up strategy. | ||||||
|  |  | ||||||
|     Args: |     Args: | ||||||
|         input_requests: |         input_requests: | ||||||
| @ -129,22 +157,44 @@ async def get_request( | |||||||
|             A lower burstiness value (0 < burstiness < 1) results |             A lower burstiness value (0 < burstiness < 1) results | ||||||
|             in more bursty requests, while a higher burstiness value |             in more bursty requests, while a higher burstiness value | ||||||
|             (burstiness > 1) results in a more uniform arrival of requests. |             (burstiness > 1) results in a more uniform arrival of requests. | ||||||
|  |         ramp_up_strategy (optional): | ||||||
|  |             The ramp-up strategy. Can be "linear" or "exponential". | ||||||
|  |             If None, uses constant request rate (specified by request_rate). | ||||||
|  |         ramp_up_start_rps (optional): | ||||||
|  |             The starting request rate for ramp-up. | ||||||
|  |         ramp_up_end_rps (optional): | ||||||
|  |             The ending request rate for ramp-up. | ||||||
|     """ |     """ | ||||||
|     input_requests: Iterable[SampleRequest] = iter(input_requests) |  | ||||||
|  |  | ||||||
|     # Calculate scale parameter theta to maintain the desired request_rate. |  | ||||||
|     assert burstiness > 0, ( |     assert burstiness > 0, ( | ||||||
|         f"A positive burstiness factor is expected, but given {burstiness}." |         f"A positive burstiness factor is expected, but given {burstiness}." | ||||||
|     ) |     ) | ||||||
|     theta = 1.0 / (request_rate * burstiness) |     # Convert to list to get length for ramp-up calculations | ||||||
|  |     if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): | ||||||
|  |         input_requests = list(input_requests) | ||||||
|  |  | ||||||
|  |     total_requests = len(input_requests) | ||||||
|  |     request_index = 0 | ||||||
|  |  | ||||||
|     for request in input_requests: |     for request in input_requests: | ||||||
|         yield request |         current_request_rate = _get_current_request_rate( | ||||||
|  |             ramp_up_strategy, | ||||||
|  |             ramp_up_start_rps, | ||||||
|  |             ramp_up_end_rps, | ||||||
|  |             request_index, | ||||||
|  |             total_requests, | ||||||
|  |             request_rate, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         if request_rate == float("inf"): |         yield request, current_request_rate | ||||||
|  |  | ||||||
|  |         request_index += 1 | ||||||
|  |  | ||||||
|  |         if current_request_rate == float("inf"): | ||||||
|             # If the request rate is infinity, then we don't need to wait. |             # If the request rate is infinity, then we don't need to wait. | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
|  |         theta = 1.0 / (current_request_rate * burstiness) | ||||||
|  |  | ||||||
|         # Sample the request interval from the gamma distribution. |         # Sample the request interval from the gamma distribution. | ||||||
|         # If burstiness is 1, it follows exponential distribution. |         # If burstiness is 1, it follows exponential distribution. | ||||||
|         interval = np.random.gamma(shape=burstiness, scale=theta) |         interval = np.random.gamma(shape=burstiness, scale=theta) | ||||||
| @ -290,6 +340,9 @@ async def benchmark( | |||||||
|     max_concurrency: Optional[int], |     max_concurrency: Optional[int], | ||||||
|     lora_modules: Optional[Iterable[str]], |     lora_modules: Optional[Iterable[str]], | ||||||
|     extra_body: Optional[dict], |     extra_body: Optional[dict], | ||||||
|  |     ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, | ||||||
|  |     ramp_up_start_rps: Optional[int] = None, | ||||||
|  |     ramp_up_end_rps: Optional[int] = None, | ||||||
| ): | ): | ||||||
|     if backend in ASYNC_REQUEST_FUNCS: |     if backend in ASYNC_REQUEST_FUNCS: | ||||||
|         request_func = ASYNC_REQUEST_FUNCS[backend] |         request_func = ASYNC_REQUEST_FUNCS[backend] | ||||||
| @ -353,7 +406,15 @@ async def benchmark( | |||||||
|  |  | ||||||
|     distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" |     distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" | ||||||
|  |  | ||||||
|     print(f"Traffic request rate: {request_rate}") |     if ramp_up_strategy is not None: | ||||||
|  |         print( | ||||||
|  |             f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " | ||||||
|  |             f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " | ||||||
|  |             "the duration of the benchmark." | ||||||
|  |         ) | ||||||
|  |     else: | ||||||
|  |         print(f"Traffic request rate: {request_rate} RPS.") | ||||||
|  |  | ||||||
|     print(f"Burstiness factor: {burstiness} ({distribution})") |     print(f"Burstiness factor: {burstiness} ({distribution})") | ||||||
|     print(f"Maximum request concurrency: {max_concurrency}") |     print(f"Maximum request concurrency: {max_concurrency}") | ||||||
|  |  | ||||||
| @ -373,7 +434,34 @@ async def benchmark( | |||||||
|  |  | ||||||
|     benchmark_start_time = time.perf_counter() |     benchmark_start_time = time.perf_counter() | ||||||
|     tasks: list[asyncio.Task] = [] |     tasks: list[asyncio.Task] = [] | ||||||
|     async for request in get_request(input_requests, request_rate, burstiness): |  | ||||||
|  |     rps_change_events = [] | ||||||
|  |     last_int_rps = -1 | ||||||
|  |     if ramp_up_strategy is not None and ramp_up_start_rps is not None: | ||||||
|  |         last_int_rps = ramp_up_start_rps | ||||||
|  |         rps_change_events.append( | ||||||
|  |             { | ||||||
|  |                 "rps": last_int_rps, | ||||||
|  |                 "timestamp": datetime.now().isoformat(), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     async for request, current_request_rate in get_request( | ||||||
|  |         input_requests, | ||||||
|  |         request_rate, | ||||||
|  |         burstiness, | ||||||
|  |         ramp_up_strategy, | ||||||
|  |         ramp_up_start_rps, | ||||||
|  |         ramp_up_end_rps, | ||||||
|  |     ): | ||||||
|  |         if ramp_up_strategy is not None: | ||||||
|  |             current_int_rps = int(current_request_rate) | ||||||
|  |             if current_int_rps > last_int_rps: | ||||||
|  |                 timestamp = datetime.now().isoformat() | ||||||
|  |                 for rps_val in range(last_int_rps + 1, current_int_rps + 1): | ||||||
|  |                     rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) | ||||||
|  |                 last_int_rps = current_int_rps | ||||||
|  |  | ||||||
|         prompt, prompt_len, output_len, mm_content = ( |         prompt, prompt_len, output_len, mm_content = ( | ||||||
|             request.prompt, |             request.prompt, | ||||||
|             request.prompt_len, |             request.prompt_len, | ||||||
| @ -397,11 +485,8 @@ async def benchmark( | |||||||
|             ignore_eos=ignore_eos, |             ignore_eos=ignore_eos, | ||||||
|             extra_body=extra_body, |             extra_body=extra_body, | ||||||
|         ) |         ) | ||||||
|         tasks.append( |         task = limited_request_func(request_func_input=request_func_input, pbar=pbar) | ||||||
|             asyncio.create_task( |         tasks.append(asyncio.create_task(task)) | ||||||
|                 limited_request_func(request_func_input=request_func_input, pbar=pbar) |  | ||||||
|             ) |  | ||||||
|         ) |  | ||||||
|     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) |     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) | ||||||
|  |  | ||||||
|     if profile: |     if profile: | ||||||
| @ -466,7 +551,7 @@ async def benchmark( | |||||||
|         "total_input_tokens": metrics.total_input, |         "total_input_tokens": metrics.total_input, | ||||||
|         "total_output_tokens": metrics.total_output, |         "total_output_tokens": metrics.total_output, | ||||||
|         "request_throughput": metrics.request_throughput, |         "request_throughput": metrics.request_throughput, | ||||||
|         "request_goodput:": metrics.request_goodput if goodput_config_dict else None, |         "request_goodput": metrics.request_goodput if goodput_config_dict else None, | ||||||
|         "output_throughput": metrics.output_throughput, |         "output_throughput": metrics.output_throughput, | ||||||
|         "total_token_throughput": metrics.total_token_throughput, |         "total_token_throughput": metrics.total_token_throughput, | ||||||
|         "input_lens": [output.prompt_len for output in outputs], |         "input_lens": [output.prompt_len for output in outputs], | ||||||
| @ -477,6 +562,9 @@ async def benchmark( | |||||||
|         "errors": [output.error for output in outputs], |         "errors": [output.error for output in outputs], | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     if rps_change_events: | ||||||
|  |         result["rps_change_events"] = rps_change_events | ||||||
|  |  | ||||||
|     def process_one_metric( |     def process_one_metric( | ||||||
|         # E.g., "ttft" |         # E.g., "ttft" | ||||||
|         metric_attribute_name: str, |         metric_attribute_name: str, | ||||||
| @ -610,6 +698,26 @@ def main(args: argparse.Namespace): | |||||||
|     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model |     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model | ||||||
|     tokenizer_mode = args.tokenizer_mode |     tokenizer_mode = args.tokenizer_mode | ||||||
|  |  | ||||||
|  |     # Validate ramp-up arguments | ||||||
|  |     if args.ramp_up_strategy is not None: | ||||||
|  |         if args.request_rate != float("inf"): | ||||||
|  |             raise ValueError( | ||||||
|  |                 "When using ramp-up, do not specify --request-rate. " | ||||||
|  |                 "The request rate will be controlled by ramp-up parameters. " | ||||||
|  |                 "Please remove the --request-rate argument." | ||||||
|  |             ) | ||||||
|  |         if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "When using --ramp-up-strategy, both --ramp-up-start-rps and " | ||||||
|  |                 "--ramp-up-end-rps must be specified" | ||||||
|  |             ) | ||||||
|  |         if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: | ||||||
|  |             raise ValueError("Ramp-up start and end RPS must be non-negative") | ||||||
|  |         if args.ramp_up_start_rps > args.ramp_up_end_rps: | ||||||
|  |             raise ValueError("Ramp-up start RPS must be less than end RPS") | ||||||
|  |         if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: | ||||||
|  |             raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") | ||||||
|  |  | ||||||
|     if args.base_url is not None: |     if args.base_url is not None: | ||||||
|         api_url = f"{args.base_url}{args.endpoint}" |         api_url = f"{args.base_url}{args.endpoint}" | ||||||
|         base_url = f"{args.base_url}" |         base_url = f"{args.base_url}" | ||||||
| @ -717,6 +825,7 @@ def main(args: argparse.Namespace): | |||||||
|             dataset_subset=args.hf_subset, |             dataset_subset=args.hf_subset, | ||||||
|             dataset_split=args.hf_split, |             dataset_split=args.hf_split, | ||||||
|             random_seed=args.seed, |             random_seed=args.seed, | ||||||
|  |             no_stream=args.no_stream, | ||||||
|         ).sample( |         ).sample( | ||||||
|             num_requests=args.num_prompts, |             num_requests=args.num_prompts, | ||||||
|             tokenizer=tokenizer, |             tokenizer=tokenizer, | ||||||
| @ -802,6 +911,9 @@ def main(args: argparse.Namespace): | |||||||
|             max_concurrency=args.max_concurrency, |             max_concurrency=args.max_concurrency, | ||||||
|             lora_modules=args.lora_modules, |             lora_modules=args.lora_modules, | ||||||
|             extra_body=sampling_params, |             extra_body=sampling_params, | ||||||
|  |             ramp_up_strategy=args.ramp_up_strategy, | ||||||
|  |             ramp_up_start_rps=args.ramp_up_start_rps, | ||||||
|  |             ramp_up_end_rps=args.ramp_up_end_rps, | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
| @ -834,6 +946,11 @@ def main(args: argparse.Namespace): | |||||||
|         result_json["burstiness"] = args.burstiness |         result_json["burstiness"] = args.burstiness | ||||||
|         result_json["max_concurrency"] = args.max_concurrency |         result_json["max_concurrency"] = args.max_concurrency | ||||||
|  |  | ||||||
|  |         if args.ramp_up_strategy is not None: | ||||||
|  |             result_json["ramp_up_strategy"] = args.ramp_up_strategy | ||||||
|  |             result_json["ramp_up_start_rps"] = args.ramp_up_start_rps | ||||||
|  |             result_json["ramp_up_end_rps"] = args.ramp_up_end_rps | ||||||
|  |  | ||||||
|         # Merge with benchmark result |         # Merge with benchmark result | ||||||
|         result_json = {**result_json, **benchmark_result} |         result_json = {**result_json, **benchmark_result} | ||||||
|  |  | ||||||
| @ -859,7 +976,10 @@ def main(args: argparse.Namespace): | |||||||
|             if args.max_concurrency is not None |             if args.max_concurrency is not None | ||||||
|             else "" |             else "" | ||||||
|         ) |         ) | ||||||
|         file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa |         if args.ramp_up_strategy is not None: | ||||||
|  |             file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa | ||||||
|  |         else: | ||||||
|  |             file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa | ||||||
|         if args.result_filename: |         if args.result_filename: | ||||||
|             file_name = args.result_filename |             file_name = args.result_filename | ||||||
|         if args.result_dir: |         if args.result_dir: | ||||||
| @ -875,7 +995,7 @@ def main(args: argparse.Namespace): | |||||||
|         save_to_pytorch_benchmark_format(args, result_json, file_name) |         save_to_pytorch_benchmark_format(args, result_json, file_name) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser( |     parser = FlexibleArgumentParser( | ||||||
|         description="Benchmark the online serving throughput." |         description="Benchmark the online serving throughput." | ||||||
|     ) |     ) | ||||||
| @ -914,6 +1034,11 @@ if __name__ == "__main__": | |||||||
|         help="Path to the sharegpt/sonnet dataset. " |         help="Path to the sharegpt/sonnet dataset. " | ||||||
|         "Or the huggingface dataset ID if using HF dataset.", |         "Or the huggingface dataset ID if using HF dataset.", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--no-stream", | ||||||
|  |         action="store_true", | ||||||
|  |         help="Do not load the dataset in streaming mode.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--max-concurrency", |         "--max-concurrency", | ||||||
|         type=int, |         type=int, | ||||||
| @ -1225,6 +1350,35 @@ if __name__ == "__main__": | |||||||
|         "script chooses a LoRA module at random.", |         "script chooses a LoRA module at random.", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     args = parser.parse_args() |     parser.add_argument( | ||||||
|  |         "--ramp-up-strategy", | ||||||
|  |         type=str, | ||||||
|  |         default=None, | ||||||
|  |         choices=["linear", "exponential"], | ||||||
|  |         help="The ramp-up strategy. This would be used to " | ||||||
|  |         "ramp up the request rate from initial RPS to final " | ||||||
|  |         "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " | ||||||
|  |         "over the duration of the benchmark.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--ramp-up-start-rps", | ||||||
|  |         type=int, | ||||||
|  |         default=None, | ||||||
|  |         help="The starting request rate for ramp-up (RPS). " | ||||||
|  |         "Needs to be specified when --ramp-up-strategy is used.", | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--ramp-up-end-rps", | ||||||
|  |         type=int, | ||||||
|  |         default=None, | ||||||
|  |         help="The ending request rate for ramp-up (RPS). " | ||||||
|  |         "Needs to be specified when --ramp-up-strategy is used.", | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|  |     args = parser.parse_args() | ||||||
|     main(args) |     main(args) | ||||||
|  | |||||||
| @ -850,7 +850,7 @@ def main(args: argparse.Namespace): | |||||||
|             json.dump(results, outfile, indent=4) |             json.dump(results, outfile, indent=4) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser( |     parser = FlexibleArgumentParser( | ||||||
|         description="Benchmark the online serving throughput." |         description="Benchmark the online serving throughput." | ||||||
|     ) |     ) | ||||||
| @ -1034,5 +1034,10 @@ if __name__ == "__main__": | |||||||
|         help="Ratio of Structured Outputs requests", |         help="Ratio of Structured Outputs requests", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     main(args) |     main(args) | ||||||
|  | |||||||
| @ -97,7 +97,7 @@ def run_vllm( | |||||||
|         assert lora_requests is None, "BeamSearch API does not support LoRA" |         assert lora_requests is None, "BeamSearch API does not support LoRA" | ||||||
|         prompts = [request.prompt for request in requests] |         prompts = [request.prompt for request in requests] | ||||||
|         # output_len should be the same for all requests. |         # output_len should be the same for all requests. | ||||||
|         output_len = requests[0][2] |         output_len = requests[0].expected_output_len | ||||||
|         for request in requests: |         for request in requests: | ||||||
|             assert request.expected_output_len == output_len |             assert request.expected_output_len == output_len | ||||||
|         start = time.perf_counter() |         start = time.perf_counter() | ||||||
| @ -356,6 +356,7 @@ def get_requests(args, tokenizer): | |||||||
|     elif args.dataset_name == "burstgpt": |     elif args.dataset_name == "burstgpt": | ||||||
|         dataset_cls = BurstGPTDataset |         dataset_cls = BurstGPTDataset | ||||||
|     elif args.dataset_name == "hf": |     elif args.dataset_name == "hf": | ||||||
|  |         common_kwargs["no_stream"] = args.no_stream | ||||||
|         if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: |         if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: | ||||||
|             dataset_cls = VisionArenaDataset |             dataset_cls = VisionArenaDataset | ||||||
|             common_kwargs["dataset_subset"] = None |             common_kwargs["dataset_subset"] = None | ||||||
| @ -595,7 +596,7 @@ def validate_args(args): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | def create_argument_parser(): | ||||||
|     parser = FlexibleArgumentParser(description="Benchmark the throughput.") |     parser = FlexibleArgumentParser(description="Benchmark the throughput.") | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--backend", |         "--backend", | ||||||
| @ -610,6 +611,11 @@ if __name__ == "__main__": | |||||||
|         help="Name of the dataset to benchmark on.", |         help="Name of the dataset to benchmark on.", | ||||||
|         default="sharegpt", |         default="sharegpt", | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--no-stream", | ||||||
|  |         action="store_true", | ||||||
|  |         help="Do not load the dataset in streaming mode.", | ||||||
|  |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--dataset", |         "--dataset", | ||||||
|         type=str, |         type=str, | ||||||
| @ -717,6 +723,12 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     parser = AsyncEngineArgs.add_cli_args(parser) |     parser = AsyncEngineArgs.add_cli_args(parser) | ||||||
|  |  | ||||||
|  |     return parser | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     parser = create_argument_parser() | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     if args.tokenizer is None: |     if args.tokenizer is None: | ||||||
|         args.tokenizer = args.model |         args.tokenizer = args.model | ||||||
|  | |||||||
| @ -19,7 +19,7 @@ from vllm import _custom_ops as ops | |||||||
| from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | ||||||
|     w8a8_block_fp8_matmul, |     w8a8_block_fp8_matmul, | ||||||
| ) | ) | ||||||
| from vllm.utils import FlexibleArgumentParser | from vllm.utils import FlexibleArgumentParser, cdiv | ||||||
|  |  | ||||||
| DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) | DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) | ||||||
| DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] | DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] | ||||||
| @ -117,14 +117,9 @@ def bench_fp8( | |||||||
|     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) |     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||||
|     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) |     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||||
|  |  | ||||||
|     def ceil_div(x: int, y: int) -> int: |     block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32) | ||||||
|         return (x + y - 1) // y |  | ||||||
|  |  | ||||||
|     block_scale_a = torch.rand( |  | ||||||
|         (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 |  | ||||||
|     ) |  | ||||||
|     block_scale_b = torch.rand( |     block_scale_b = torch.rand( | ||||||
|         ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 |         cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32 | ||||||
|     ) |     ) | ||||||
|     block_scale_a_M_major = block_scale_a.t().contiguous().t() |     block_scale_a_M_major = block_scale_a.t().contiguous().t() | ||||||
|     block_scale_b_K_major = block_scale_b.t().contiguous().t() |     block_scale_b_K_major = block_scale_b.t().contiguous().t() | ||||||
|  | |||||||
| @ -11,6 +11,80 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm | |||||||
| from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant | from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant | ||||||
| from vllm.triton_utils import triton | from vllm.triton_utils import triton | ||||||
|  |  | ||||||
|  | PROVIDER_CFGS = { | ||||||
|  |     "torch-bf16": dict(enabled=True), | ||||||
|  |     "fp8-tensor-w-token-a": dict( | ||||||
|  |         w="tensor", a="token", no_a_quant=False, enabled=False | ||||||
|  |     ), | ||||||
|  |     "fp8-tensor-w-tensor-a": dict( | ||||||
|  |         w="tensor", a="tensor", no_a_quant=False, enabled=True | ||||||
|  |     ), | ||||||
|  |     "fp8-channel-w-token-a": dict( | ||||||
|  |         w="channel", a="token", no_a_quant=False, enabled=True | ||||||
|  |     ), | ||||||
|  |     "fp8-channel-w-tensor-a": dict( | ||||||
|  |         w="channel", a="tensor", no_a_quant=False, enabled=False | ||||||
|  |     ), | ||||||
|  |     "fp8-tensor-w-token-a-noquant": dict( | ||||||
|  |         w="tensor", a="token", no_a_quant=True, enabled=False | ||||||
|  |     ), | ||||||
|  |     "fp8-tensor-w-tensor-a-noquant": dict( | ||||||
|  |         w="tensor", a="tensor", no_a_quant=True, enabled=True | ||||||
|  |     ), | ||||||
|  |     "fp8-channel-w-token-a-noquant": dict( | ||||||
|  |         w="channel", a="token", no_a_quant=True, enabled=True | ||||||
|  |     ), | ||||||
|  |     "fp8-channel-w-tensor-a-noquant": dict( | ||||||
|  |         w="channel", a="tensor", no_a_quant=True, enabled=False | ||||||
|  |     ), | ||||||
|  | } | ||||||
|  |  | ||||||
|  | _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str): | ||||||
|  |     if w_type == "tensor": | ||||||
|  |         scale_b = torch.ones(1, device=device, dtype=torch.float32) | ||||||
|  |         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||||
|  |     else: | ||||||
|  |         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True) | ||||||
|  |     return b_fp8.t(), scale_b_fp8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def build_fp8_runner(cfg, a, b, dtype, device): | ||||||
|  |     b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device) | ||||||
|  |  | ||||||
|  |     scale_a_const = ( | ||||||
|  |         torch.ones(1, device=device, dtype=torch.float32) | ||||||
|  |         if cfg["a"] == "tensor" | ||||||
|  |         else None | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     if cfg["no_a_quant"]: | ||||||
|  |         if cfg["a"] == "tensor": | ||||||
|  |             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) | ||||||
|  |         else: | ||||||
|  |             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) | ||||||
|  |  | ||||||
|  |         def run(): | ||||||
|  |             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||||
|  |  | ||||||
|  |         return run | ||||||
|  |  | ||||||
|  |     if cfg["a"] == "tensor": | ||||||
|  |  | ||||||
|  |         def run(): | ||||||
|  |             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) | ||||||
|  |             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||||
|  |  | ||||||
|  |     else: | ||||||
|  |  | ||||||
|  |         def run(): | ||||||
|  |             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) | ||||||
|  |             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||||
|  |  | ||||||
|  |     return run | ||||||
|  |  | ||||||
|  |  | ||||||
| @triton.testing.perf_report( | @triton.testing.perf_report( | ||||||
|     triton.testing.Benchmark( |     triton.testing.Benchmark( | ||||||
| @ -18,28 +92,8 @@ from vllm.triton_utils import triton | |||||||
|         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], |         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], | ||||||
|         x_log=False, |         x_log=False, | ||||||
|         line_arg="provider", |         line_arg="provider", | ||||||
|         line_vals=[ |         line_vals=_enabled, | ||||||
|             "torch-bf16", |         line_names=_enabled, | ||||||
|             # "fp8-tensor-w-token-a", |  | ||||||
|             "fp8-tensor-w-tensor-a", |  | ||||||
|             "fp8-channel-w-token-a", |  | ||||||
|             # "fp8-channel-w-tensor-a", |  | ||||||
|             # "fp8-tensor-w-token-a-noquant", |  | ||||||
|             "fp8-tensor-w-tensor-a-noquant", |  | ||||||
|             "fp8-channel-w-token-a-noquant", |  | ||||||
|             # "fp8-channel-w-tensor-a-noquant", |  | ||||||
|         ], |  | ||||||
|         line_names=[ |  | ||||||
|             "torch-bf16", |  | ||||||
|             # "fp8-tensor-w-token-a", |  | ||||||
|             "fp8-tensor-w-tensor-a", |  | ||||||
|             "fp8-channel-w-token-a", |  | ||||||
|             # "fp8-channel-w-tensor-a", |  | ||||||
|             # "fp8-tensor-w-token-a-noquant", |  | ||||||
|             "fp8-tensor-w-tensor-a-noquant", |  | ||||||
|             "fp8-channel-w-token-a-noquant", |  | ||||||
|             # "fp8-channel-w-tensor-a-noquant", |  | ||||||
|         ], |  | ||||||
|         ylabel="TFLOP/s (larger is better)", |         ylabel="TFLOP/s (larger is better)", | ||||||
|         plot_name="BF16 vs FP8 GEMMs", |         plot_name="BF16 vs FP8 GEMMs", | ||||||
|         args={}, |         args={}, | ||||||
| @ -50,144 +104,34 @@ def benchmark(batch_size, provider, N, K): | |||||||
|     device = "cuda" |     device = "cuda" | ||||||
|     dtype = torch.bfloat16 |     dtype = torch.bfloat16 | ||||||
|  |  | ||||||
|     # Create input tensors |  | ||||||
|     a = torch.randn((M, K), device=device, dtype=dtype) |     a = torch.randn((M, K), device=device, dtype=dtype) | ||||||
|     b = torch.randn((N, K), device=device, dtype=dtype) |     b = torch.randn((N, K), device=device, dtype=dtype) | ||||||
|  |  | ||||||
|     quantiles = [0.5, 0.2, 0.8] |     quantiles = [0.5, 0.2, 0.8] | ||||||
|  |  | ||||||
|     if "torch-bf16" in provider: |     if provider == "torch-bf16": | ||||||
|         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( |         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||||||
|             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles |             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles | ||||||
|         ) |         ) | ||||||
|  |     else: | ||||||
|     elif "fp8" in provider: |         cfg = PROVIDER_CFGS[provider] | ||||||
|         # Weights are always quantized ahead of time |         run_quant = build_fp8_runner(cfg, a, b, dtype, device) | ||||||
|         if "noquant" in provider: |  | ||||||
|             # For no quantization, we just measure the GEMM |  | ||||||
|             if "tensor-w-token-a" in provider: |  | ||||||
|                 # Dynamic per-token quant for A, per-tensor quant for B |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) |  | ||||||
|                 assert scale_b_fp8.numel() == 1 |  | ||||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( |  | ||||||
|                     a, use_per_token_if_dynamic=True |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|             elif "tensor-w-tensor-a" in provider: |  | ||||||
|                 # Static per-tensor quantization with fixed scales |  | ||||||
|                 # for both A and B |  | ||||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) |  | ||||||
|                 scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |  | ||||||
|                 assert scale_b_fp8.numel() == 1 |  | ||||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|             elif "channel-w-token-a" in provider: |  | ||||||
|                 # Static per-channel quantization for weights, per-token |  | ||||||
|                 # quant for A |  | ||||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |  | ||||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() |  | ||||||
|                 assert scale_b_fp8.numel() == N |  | ||||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( |  | ||||||
|                     a, use_per_token_if_dynamic=True |  | ||||||
|                 ) |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|             elif "channel-w-tensor-a" in provider: |  | ||||||
|                 # Static per-channel quantization for weights, per-tensor |  | ||||||
|                 # quant for A |  | ||||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) |  | ||||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |  | ||||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() |  | ||||||
|                 assert scale_b_fp8.numel() == N |  | ||||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|         else: |  | ||||||
|             # In these cases, we quantize the activations during the GEMM call |  | ||||||
|             if "tensor-w-token-a" in provider: |  | ||||||
|                 # Dynamic per-token quant for A, per-tensor quant for B |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) |  | ||||||
|                 assert scale_b_fp8.numel() == 1 |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( |  | ||||||
|                         a, use_per_token_if_dynamic=True |  | ||||||
|                     ) |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|             elif "tensor-w-tensor-a" in provider: |  | ||||||
|                 # Static per-tensor quantization with fixed scales |  | ||||||
|                 # for both A and B |  | ||||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) |  | ||||||
|                 scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |  | ||||||
|                 assert scale_b_fp8.numel() == 1 |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|             elif "channel-w-token-a" in provider: |  | ||||||
|                 # Static per-channel quantization for weights, per-token |  | ||||||
|                 # quant for A |  | ||||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |  | ||||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() |  | ||||||
|                 assert scale_b_fp8.numel() == N |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( |  | ||||||
|                         a, use_per_token_if_dynamic=True |  | ||||||
|                     ) |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|             elif "channel-w-tensor-a" in provider: |  | ||||||
|                 # Static per-channel quantization for weights, per-tensor |  | ||||||
|                 # quant for A |  | ||||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) |  | ||||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) |  | ||||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) |  | ||||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() |  | ||||||
|                 assert scale_b_fp8.numel() == N |  | ||||||
|  |  | ||||||
|                 def run_quant(): |  | ||||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) |  | ||||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) |  | ||||||
|  |  | ||||||
|         b_fp8 = b_fp8.t() |  | ||||||
|  |  | ||||||
|         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( |         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||||||
|             lambda: run_quant(), quantiles=quantiles |             lambda: run_quant(), quantiles=quantiles | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     # Calculate TFLOP/s, two flops per multiply-add |     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) | ||||||
|     tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3) |     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) | ||||||
|     return tflops(ms), tflops(max_ms), tflops(min_ms) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def prepare_shapes(args): | def prepare_shapes(args): | ||||||
|     KN_model_names = [] |     out = [] | ||||||
|     models_tps = list(itertools.product(args.models, args.tp_sizes)) |     for model, tp_size in itertools.product(args.models, args.tp_sizes): | ||||||
|     for model, tp_size in models_tps: |         for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): | ||||||
|         assert model in WEIGHT_SHAPES |             KN[tp_dim] //= tp_size | ||||||
|         for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): |  | ||||||
|             KN[tp_split_dim] = KN[tp_split_dim] // tp_size |  | ||||||
|             KN.append(model) |             KN.append(model) | ||||||
|             KN_model_names.append(KN) |             out.append(KN) | ||||||
|     return KN_model_names |     return out | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
| @ -197,21 +141,13 @@ if __name__ == "__main__": | |||||||
|         nargs="+", |         nargs="+", | ||||||
|         type=str, |         type=str, | ||||||
|         default=["meta-llama/Llama-3.1-8B-Instruct"], |         default=["meta-llama/Llama-3.1-8B-Instruct"], | ||||||
|         choices=[*WEIGHT_SHAPES.keys()], |         choices=list(WEIGHT_SHAPES.keys()), | ||||||
|         help="List of models to benchmark", |  | ||||||
|     ) |  | ||||||
|     parser.add_argument( |  | ||||||
|         "--tp-sizes", |  | ||||||
|         nargs="+", |  | ||||||
|         type=int, |  | ||||||
|         default=[1], |  | ||||||
|         help="List of tensor parallel sizes", |  | ||||||
|     ) |     ) | ||||||
|  |     parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     KN_model_names = prepare_shapes(args) |     for K, N, model in prepare_shapes(args): | ||||||
|     for K, N, model_name in KN_model_names: |         print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") | ||||||
|         print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") |  | ||||||
|         benchmark.run( |         benchmark.run( | ||||||
|             print_data=True, |             print_data=True, | ||||||
|             show_plots=True, |             show_plots=True, | ||||||
|  | |||||||
							
								
								
									
										169
									
								
								benchmarks/kernels/bench_int8_gemm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								benchmarks/kernels/bench_int8_gemm.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,169 @@ | |||||||
|  | # SPDX-License-Identifier: Apache-2.0 | ||||||
|  | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|  | import argparse | ||||||
|  | import copy | ||||||
|  | import itertools | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from weight_shapes import WEIGHT_SHAPES | ||||||
|  |  | ||||||
|  | from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm | ||||||
|  | from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant | ||||||
|  | from vllm.triton_utils import triton | ||||||
|  |  | ||||||
# Benchmark providers, keyed by the label used for each plotted line.
# For the int8 entries:
#   "w"          - weight scale granularity ("tensor" or "channel")
#   "a"          - activation scale granularity ("tensor" or "token")
#   "no_a_quant" - True: quantize the activation once, outside the timed
#                  callable; False: quantize it inside the timed callable
#   "enabled"    - only enabled providers are benchmarked
PROVIDER_CFGS = {
    "torch-bf16": {"enabled": True},
    "int8-tensor-w-token-a": {
        "w": "tensor",
        "a": "token",
        "no_a_quant": False,
        "enabled": False,
    },
    "int8-tensor-w-tensor-a": {
        "w": "tensor",
        "a": "tensor",
        "no_a_quant": False,
        "enabled": True,
    },
    "int8-channel-w-token-a": {
        "w": "channel",
        "a": "token",
        "no_a_quant": False,
        "enabled": True,
    },
    "int8-channel-w-tensor-a": {
        "w": "channel",
        "a": "tensor",
        "no_a_quant": False,
        "enabled": False,
    },
    "int8-tensor-w-token-a-noquant": {
        "w": "tensor",
        "a": "token",
        "no_a_quant": True,
        "enabled": False,
    },
    "int8-tensor-w-tensor-a-noquant": {
        "w": "tensor",
        "a": "tensor",
        "no_a_quant": True,
        "enabled": True,
    },
    "int8-channel-w-token-a-noquant": {
        "w": "channel",
        "a": "token",
        "no_a_quant": True,
        "enabled": True,
    },
    "int8-channel-w-tensor-a-noquant": {
        "w": "channel",
        "a": "tensor",
        "no_a_quant": True,
        "enabled": False,
    },
}
|  |  | ||||||
|  |  | ||||||
def _quant_weight(b, w_type, device):
    """Quantize the weight ``b`` to int8 and return ``(b_int8.t(), scales)``.

    For ``w_type == "tensor"`` a single unit scale is passed to the
    quantization op and exactly one scale is expected back.  For any other
    value (the "channel" case) the op is called without a scale and one
    scale per row of ``b`` is expected back.
    """
    if w_type == "tensor":
        unit_scale = torch.ones(1, device=device, dtype=torch.float32)
        quantized, scales, _ = vllm_scaled_int8_quant(b, unit_scale)
        expected_scales = 1
    else:  # channel
        quantized, scales, _ = vllm_scaled_int8_quant(b)
        expected_scales = b.shape[0]

    assert scales.numel() == expected_scales
    # The quantized weight is handed back transposed.
    return quantized.t(), scales
|  |  | ||||||
|  |  | ||||||
def build_int8_runner(cfg, a, b, dtype, device):
    """Return a zero-argument callable that runs one int8 scaled matmul.

    The weight ``b`` is always quantized here, before the callable exists.
    ``cfg["a"] == "tensor"`` quantizes the activation with a fixed unit
    scale; any other value (the "token" case) calls the quantization op
    without a scale.  When ``cfg["no_a_quant"]`` is truthy the activation is
    quantized once up front and the callable runs only the matmul; otherwise
    the callable re-quantizes the activation on every invocation.
    """
    weight_q, weight_scale = _quant_weight(b, cfg["w"], device)

    # Pick the activation quantizer once so both code paths below share it.
    if cfg["a"] == "tensor":
        unit_scale = torch.ones(1, device=device, dtype=torch.float32)

        def quantize_activation():
            return vllm_scaled_int8_quant(a, unit_scale)

    else:  # token

        def quantize_activation():
            return vllm_scaled_int8_quant(a)

    if cfg["no_a_quant"]:
        # Activation quantization stays outside of the returned callable.
        act_q, act_scale, _ = quantize_activation()

        def run_quant():
            return vllm_scaled_mm(act_q, weight_q, act_scale, weight_scale, dtype)

    else:
        # Activation quantization is part of the returned callable.
        def run_quant():
            act_q, act_scale, _ = quantize_activation()
            return vllm_scaled_mm(act_q, weight_q, act_scale, weight_scale, dtype)

    return run_quant
|  |  | ||||||
|  |  | ||||||
# Provider labels that are switched on, in PROVIDER_CFGS order.
_enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg.get("enabled")]
|  |  | ||||||
|  |  | ||||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=list(_enabled),
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs INT8 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one (batch_size x K) @ (K x N) GEMM for ``provider``.

    Returns three TFLOP/s figures computed from the timings reported for the
    quantiles ``[0.5, 0.2, 0.8]``: the 0.5 value first, then the value from
    the 0.8-quantile time, then the value from the 0.2-quantile time.  The
    last two are swapped relative to the timings because a longer time means
    a lower throughput.
    """
    device = "cuda"
    dtype = torch.bfloat16
    M = batch_size
    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    if provider == "torch-bf16":

        def timed():
            return torch.nn.functional.linear(a, b)

    else:
        timed = build_int8_runner(PROVIDER_CFGS[provider], a, b, dtype, device)

    median_ms, low_ms, high_ms = triton.testing.do_bench_cudagraph(
        timed, quantiles=[0.5, 0.2, 0.8]
    )

    # A GEMM performs 2 * M * N * K floating point operations.
    flop = 2 * M * N * K

    def to_tflops(t_ms):
        return flop * 1e-12 / (t_ms * 1e-3)

    return to_tflops(median_ms), to_tflops(high_ms), to_tflops(low_ms)
|  |  | ||||||
|  |  | ||||||
def prepare_shapes(args):
    """Expand the selected models and TP sizes into ``[K, N, model]`` entries.

    For every model / tensor-parallel size combination, each weight shape of
    the model has its ``tp_dim`` dimension floor-divided by the TP size and
    the model name appended.
    """
    shapes = []
    for model in args.models:
        for tp_size in args.tp_sizes:
            # Deep copy so the in-place edits never touch WEIGHT_SHAPES itself.
            for dims, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
                dims[tp_dim] //= tp_size
                dims.append(model)
                shapes.append(dims)
    return shapes
|  |  | ||||||
|  |  | ||||||
if __name__ == "__main__":
    # Command line: which models' weight shapes to sweep and which
    # tensor-parallel sizes to apply to them.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        choices=list(WEIGHT_SHAPES.keys()),
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        type=int,
        nargs="+",
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    # One benchmark sweep (over batch sizes and providers) per weight shape.
    shapes = prepare_shapes(args)
    for K, N, model in shapes:
        print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
        benchmark.run(
            N=N,
            K=K,
            print_data=True,
            show_plots=True,
            save_path=f"bench_int8_res_n{N}_k{K}",
        )

    print("Benchmark finished!")
							
								
								
									
										141
									
								
								benchmarks/kernels/bench_nvfp4_gemm.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										141
									
								
								benchmarks/kernels/bench_nvfp4_gemm.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,141 @@ | |||||||
|  | # SPDX-License-Identifier: Apache-2.0 | ||||||
|  | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|  | import argparse | ||||||
|  | import copy | ||||||
|  | import itertools | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from weight_shapes import WEIGHT_SHAPES | ||||||
|  |  | ||||||
|  | from vllm import _custom_ops as ops | ||||||
|  | from vllm.platforms import current_platform | ||||||
|  | from vllm.scalar_type import scalar_types | ||||||
|  | from vllm.triton_utils import triton | ||||||
|  |  | ||||||
# Refuse to run at import time unless the device reports capability 100.
if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")


# Largest representable magnitudes of the two formats used to derive the
# global scales below.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Providers keyed by plot label.  "no_a_quant" decides whether the activation
# is quantized ahead of time (True) or inside the timed callable (False).
PROVIDER_CFGS = {
    "torch-bf16": {"enabled": True},
    "nvfp4": {"no_a_quant": False, "enabled": True},
    "nvfp4-noquant": {"no_a_quant": True, "enabled": True},
}

# Provider labels that are switched on, in PROVIDER_CFGS order.
_enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg["enabled"]]
|  |  | ||||||
|  |  | ||||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
    """Quantize the weight ``b`` with ``ops.scaled_fp4_quant``.

    Returns ``(b_fp4, scale_b_fp4, b_global_scale)`` where the global scale is
    ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX`` divided by the largest absolute
    value in ``b``.  ``device`` is accepted but not used.
    """
    weight_amax = torch.abs(b).max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / weight_amax
    weight_fp4, weight_scale = ops.scaled_fp4_quant(b, global_scale)
    return weight_fp4, weight_scale, global_scale
|  |  | ||||||
|  |  | ||||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
    """Return a zero-argument callable that runs one NVFP4 GEMM.

    The weight is quantized here, before the callable exists.  When
    ``cfg["no_a_quant"]`` is truthy the activation is quantized once up front
    and the callable runs only the GEMM; otherwise the callable quantizes the
    activation on every invocation.
    """
    weight_fp4, weight_scale, weight_global_scale = _quant_weight_nvfp4(b, device)

    # Global scale for the activation.
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    act_amax = torch.abs(a).max().to(torch.float32)
    act_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / act_amax

    # Alpha passed to the GEMM: reciprocal of the product of both global scales.
    alpha = 1.0 / (act_global_scale * weight_global_scale)

    def gemm(act_fp4, act_scale):
        return ops.cutlass_scaled_fp4_mm(
            act_fp4, weight_fp4, act_scale, weight_scale, alpha, dtype
        )

    if cfg["no_a_quant"]:
        # Activation quantization stays outside of the returned callable.
        pre_fp4, pre_scale = ops.scaled_fp4_quant(a, act_global_scale)

        def run():
            return gemm(pre_fp4, pre_scale)

    else:
        # Activation quantization is part of the returned callable.
        def run():
            act_fp4, act_scale = ops.scaled_fp4_quant(a, act_global_scale)
            return gemm(act_fp4, act_scale)

    return run
|  |  | ||||||
|  |  | ||||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one (batch_size x K) @ (K x N) GEMM for ``provider``.

    Returns three TFLOP/s figures computed from the timings reported for the
    quantiles ``[0.5, 0.2, 0.8]``: the 0.5 value first, then the value from
    the 0.8-quantile time, then the value from the 0.2-quantile time.  The
    last two are swapped relative to the timings because a longer time means
    a lower throughput.
    """
    device = "cuda"
    dtype = torch.bfloat16
    M = batch_size

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    if provider == "torch-bf16":

        def timed():
            return torch.nn.functional.linear(a, b)

    else:
        timed = build_nvfp4_runner(PROVIDER_CFGS[provider], a, b, dtype, device)

    median_ms, low_ms, high_ms = triton.testing.do_bench_cudagraph(
        timed, quantiles=[0.5, 0.2, 0.8]
    )

    # A GEMM performs 2 * M * N * K floating point operations.
    flop = 2 * M * N * K

    def to_tflops(t_ms):
        return flop * 1e-12 / (t_ms * 1e-3)

    return to_tflops(median_ms), to_tflops(high_ms), to_tflops(low_ms)
|  |  | ||||||
|  |  | ||||||
def prepare_shapes(args):
    """Expand the selected models and TP sizes into ``[K, N, model]`` entries.

    For every model / tensor-parallel size combination, each weight shape of
    the model has its ``tp_dim`` dimension floor-divided by the TP size and
    the model name appended.
    """
    shapes = []
    for model in args.models:
        for tp_size in args.tp_sizes:
            # Deep copy so the in-place edits never touch WEIGHT_SHAPES itself.
            for dims, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
                dims[tp_dim] //= tp_size
                dims.append(model)
                shapes.append(dims)
    return shapes
|  |  | ||||||
|  |  | ||||||
if __name__ == "__main__":
    # Command line: which models' weight shapes to sweep and which
    # tensor-parallel sizes to apply to them.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        choices=list(WEIGHT_SHAPES.keys()),
        default=["meta-llama/Llama-3.1-8B-Instruct"],
    )
    parser.add_argument("--tp-sizes", type=int, nargs="+", default=[1])
    args = parser.parse_args()

    # One benchmark sweep (over batch sizes and providers) per weight shape.
    shapes = prepare_shapes(args)
    for K, N, model in shapes:
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        benchmark.run(
            N=N,
            K=K,
            print_data=True,
            show_plots=True,
            save_path=f"bench_nvfp4_res_n{N}_k{K}",
        )

    print("Benchmark finished!")
							
								
								
									
										98
									
								
								benchmarks/kernels/bench_per_token_quant_fp8.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								benchmarks/kernels/bench_per_token_quant_fp8.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,98 @@ | |||||||
|  | # SPDX-License-Identifier: Apache-2.0 | ||||||
|  | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|  | import itertools | ||||||
|  | from typing import Callable | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from vllm import _custom_ops as ops | ||||||
|  | from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config | ||||||
|  | from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 | ||||||
|  | from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape | ||||||
|  | from vllm.triton_utils import triton | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # TODO(luka): use standalone_compile utility | ||||||
|  | def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): | ||||||
|  |     def inner(*args): | ||||||
|  |         torch._dynamo.mark_dynamic(args[arg_index], dim_index) | ||||||
|  |         return fn(*args) | ||||||
|  |  | ||||||
|  |     return inner | ||||||
|  |  | ||||||
|  |  | ||||||
# Raise the recompile limit; the module below is compiled with dynamic=False
# and the benchmark feeds it many different shapes.
torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    # First dim is explicitly dynamic to simulate vLLM usage
    torch_per_token_quant_fp8 = with_dyn_arg(
        torch.compile(
            QuantFP8(False, GroupShape.PER_TOKEN),
            fullgraph=True,
            dynamic=False,  # recompile for different shapes
        ),
        0,
        0,
    )
|  |  | ||||||
|  |  | ||||||
def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``ops.scaled_fp8_quant`` on ``input`` with no further arguments."""
    quantized_and_scale = ops.scaled_fp8_quant(input)
    return quantized_and_scale
|  |  | ||||||
|  |  | ||||||
def calculate_diff(batch_size: int, seq_len: int):
    """Compare the torch-compiled and CUDA quantization paths and print a verdict.

    Both implementations are run on the same random float16 input of shape
    ``(batch_size * seq_len, 4096)``; their quantized outputs (compared as
    float32) and their scales must agree within ``rtol=1e-3, atol=1e-5``.
    """
    x = torch.rand(
        (batch_size * seq_len, 4096),
        dtype=torch.float16,
        device=torch.device("cuda"),
    )

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    outputs_match = torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    # Scales are only compared when the quantized values already agree.
    everything_matches = outputs_match and torch.allclose(
        cuda_scale, torch_scale, rtol=1e-3, atol=1e-5
    )

    if everything_matches:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
|  |  | ||||||
|  |  | ||||||
# Sweep axes for the benchmark; every (batch_size, seq_len) pair is measured,
# with seq_len varying fastest.
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

configs = [(bs, sl) for bs in batch_size_range for sl in seq_len_range]
|  |  | ||||||
|  |  | ||||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    """Time per-token FP8 quantization of a ``(batch_size * seq_len, 4096)`` input.

    Args:
        batch_size: Batch size; multiplied with ``seq_len`` to get the row count.
        seq_len: Sequence length.
        provider: ``"torch"`` for the torch-compiled path or ``"cuda"`` for
            the custom op.

    Returns:
        ``(median, low, high)`` latencies in microseconds, taken from the
        0.5, 0.2 and 0.8 quantiles respectively.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    # Reject unknown providers up front instead of failing later on an
    # unbound local.
    if provider not in ("torch", "cuda"):
        raise ValueError(f"Unknown provider: {provider!r}")

    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    # The input clone is part of the timed callable for both providers.
    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    else:
        fn = lambda: cuda_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    # perf_report unpacks the result as (mean, min, max).  These are
    # latencies, so the low quantile is the minimum and must come second
    # (the swapped order is only right for throughput metrics).
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|  |  | ||||||
|  |  | ||||||
if __name__ == "__main__":
    # Check that both implementations agree on one input, then run the sweep.
    calculate_diff(4, 4096)
    benchmark_quantization.run(print_data=True)
| @ -113,6 +113,7 @@ def bench_run( | |||||||
|         w2_scale: torch.Tensor, |         w2_scale: torch.Tensor, | ||||||
|         topk_weights: torch.Tensor, |         topk_weights: torch.Tensor, | ||||||
|         topk_ids: torch.Tensor, |         topk_ids: torch.Tensor, | ||||||
|  |         per_act_token: bool, | ||||||
|         num_repeats: int, |         num_repeats: int, | ||||||
|     ): |     ): | ||||||
|         for _ in range(num_repeats): |         for _ in range(num_repeats): | ||||||
| @ -124,7 +125,8 @@ def bench_run( | |||||||
|                 topk_ids, |                 topk_ids, | ||||||
|                 w1_scale, |                 w1_scale, | ||||||
|                 w2_scale, |                 w2_scale, | ||||||
|                 a1_scale=a_scale, |                 per_act_token, | ||||||
|  |                 a1_scale=None, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def run_cutlass_from_graph( |     def run_cutlass_from_graph( | ||||||
| @ -148,7 +150,8 @@ def bench_run( | |||||||
|                 topk_ids, |                 topk_ids, | ||||||
|                 w1_scale, |                 w1_scale, | ||||||
|                 w2_scale, |                 w2_scale, | ||||||
|                 a1_scale=a_scale, |                 per_act_token, | ||||||
|  |                 a1_scale=None, | ||||||
|             ) |             ) | ||||||
|  |  | ||||||
|     def run_triton_from_graph( |     def run_triton_from_graph( | ||||||
| @ -227,6 +230,7 @@ def bench_run( | |||||||
|         "w2_q": w2_q, |         "w2_q": w2_q, | ||||||
|         "w1_scale": w1_scale, |         "w1_scale": w1_scale, | ||||||
|         "w2_scale": w2_scale, |         "w2_scale": w2_scale, | ||||||
|  |         "per_act_token": per_act_token, | ||||||
|         # cuda graph params |         # cuda graph params | ||||||
|         "cutlass_graph": cutlass_graph, |         "cutlass_graph": cutlass_graph, | ||||||
|         "triton_graph": triton_graph, |         "triton_graph": triton_graph, | ||||||
| @ -287,12 +291,13 @@ def bench_run( | |||||||
|         w2_scale, |         w2_scale, | ||||||
|         topk_weights, |         topk_weights, | ||||||
|         topk_ids, |         topk_ids, | ||||||
|  |         per_act_token, | ||||||
|         num_warmup, |         num_warmup, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     results.append( |     results.append( | ||||||
|         benchmark.Timer( |         benchmark.Timer( | ||||||
|             stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)",  # noqa: E501 |             stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501 | ||||||
|             globals=globals, |             globals=globals, | ||||||
|             label=label, |             label=label, | ||||||
|             sub_label=sub_label, |             sub_label=sub_label, | ||||||
|  | |||||||
| @ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: | |||||||
|  |  | ||||||
|         fn = lambda: ops.gptq_marlin_gemm( |         fn = lambda: ops.gptq_marlin_gemm( | ||||||
|             a=bt.a, |             a=bt.a, | ||||||
|  |             c=None, | ||||||
|             b_q_weight=w_q, |             b_q_weight=w_q, | ||||||
|             b_scales=w_s, |             b_scales=w_s, | ||||||
|  |             global_scale=None, | ||||||
|             b_zeros=w_zp, |             b_zeros=w_zp, | ||||||
|             g_idx=g_idx, |             g_idx=g_idx, | ||||||
|             perm=sort_indices, |             perm=sort_indices, | ||||||
|  | |||||||
| @ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( | |||||||
|     MARLIN_SUPPORTED_GROUP_SIZES, |     MARLIN_SUPPORTED_GROUP_SIZES, | ||||||
|     query_marlin_supported_quant_types, |     query_marlin_supported_quant_types, | ||||||
| ) | ) | ||||||
|  | from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( | ||||||
|  |     FP4_MARLIN_SUPPORTED_GROUP_SIZES, | ||||||
|  |     rand_marlin_weight_fp4_like, | ||||||
|  | ) | ||||||
|  | from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( | ||||||
|  |     marlin_quant_fp8_torch, | ||||||
|  | ) | ||||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( | from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( | ||||||
|     MarlinWorkspace, |     MarlinWorkspace, | ||||||
|  |     awq_marlin_quantize, | ||||||
|     marlin_quantize, |     marlin_quantize, | ||||||
| ) | ) | ||||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( | from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( | ||||||
| @ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( | |||||||
|     quantize_weights, |     quantize_weights, | ||||||
|     sort_weights, |     sort_weights, | ||||||
| ) | ) | ||||||
| from vllm.scalar_type import ScalarType | from vllm.scalar_type import ScalarType, scalar_types | ||||||
| from vllm.utils import FlexibleArgumentParser | from vllm.utils import FlexibleArgumentParser | ||||||
|  |  | ||||||
| DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] | DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] | ||||||
| @ -57,80 +65,144 @@ def bench_run( | |||||||
|     size_n: int, |     size_n: int, | ||||||
| ): | ): | ||||||
|     label = "Quant Matmul" |     label = "Quant Matmul" | ||||||
|  |  | ||||||
|     sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( |     sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( | ||||||
|         model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n |         model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     print(f"Testing: {sub_label}") |     print(f"Testing: {sub_label}") | ||||||
|  |  | ||||||
|     a = torch.randn(size_m, size_k).to(torch.half).cuda() |     a = torch.randn(size_m, size_k).to(torch.half).cuda() | ||||||
|     b = torch.rand(size_k, size_n).to(torch.half).cuda() |     b = torch.rand(size_k, size_n).to(torch.half).cuda() | ||||||
|  |     has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] | ||||||
|  |     if act_order and (group_size == -1 or group_size == size_k or has_zp): | ||||||
|  |         return | ||||||
|  |     if size_k % group_size != 0: | ||||||
|  |         return | ||||||
|  |  | ||||||
|     a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() |     marlin_24_supported = ( | ||||||
|  |         quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES | ||||||
|     # Marlin quant |         and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES | ||||||
|     ( |  | ||||||
|         marlin_w_ref, |  | ||||||
|         marlin_q_w, |  | ||||||
|         marlin_s, |  | ||||||
|         marlin_g_idx, |  | ||||||
|         marlin_sort_indices, |  | ||||||
|         marlin_rand_perm, |  | ||||||
|     ) = marlin_quantize(b, quant_type, group_size, act_order) |  | ||||||
|  |  | ||||||
|     # Marlin_24 quant |  | ||||||
|     (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( |  | ||||||
|         marlin_24_quantize(b, quant_type, group_size) |  | ||||||
|     ) |     ) | ||||||
|  |     repack_supported = ( | ||||||
|     marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) |         quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES | ||||||
|  |         and group_size in MARLIN_SUPPORTED_GROUP_SIZES | ||||||
|     # GPTQ quant |  | ||||||
|     (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( |  | ||||||
|         b, quant_type, group_size, act_order |  | ||||||
|     ) |     ) | ||||||
|     q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) |     allspark_supported = ( | ||||||
|  |  | ||||||
|     # For act_order, sort the "weights" and "g_idx" |  | ||||||
|     # so that group ids are increasing |  | ||||||
|     repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) |  | ||||||
|     if act_order: |  | ||||||
|         (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) |  | ||||||
|  |  | ||||||
|     # Prepare |  | ||||||
|     marlin_workspace = MarlinWorkspace( |  | ||||||
|         size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     marlin_24_workspace = MarlinWorkspace( |  | ||||||
|         size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL |  | ||||||
|     ) |  | ||||||
|     marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) |  | ||||||
|  |  | ||||||
|     # AllSpark W8A16 quant |  | ||||||
|     as_supported_case = ( |  | ||||||
|         quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES |         quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES | ||||||
|         and group_size == -1 |         and group_size == -1 | ||||||
|         and not act_order |         and not act_order | ||||||
|         and is_k_full |         and is_k_full | ||||||
|     ) |     ) | ||||||
|     if as_supported_case: |  | ||||||
|         properties = torch.cuda.get_device_properties(b.device.index) |  | ||||||
|         sm_count = properties.multi_processor_count |  | ||||||
|         sm_version = properties.major * 10 + properties.minor |  | ||||||
|  |  | ||||||
|         supported_arch = sm_version >= 80 and sm_version < 90 |     def gen_marlin_params(): | ||||||
|         as_supported_case = as_supported_case and supported_arch |         # Marlin quant | ||||||
|         if supported_arch: |         marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None | ||||||
|             has_zp = False |         if quant_type == scalar_types.float4_e2m1f: | ||||||
|             w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) |             if group_size != 16 or act_order: | ||||||
|             qw = qw.to(torch.uint8) |                 return | ||||||
|  |             marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( | ||||||
|             qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( |                 b.T, group_size | ||||||
|                 qw, s, zp, has_zp |  | ||||||
|             ) |             ) | ||||||
|             CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD |         elif quant_type == scalar_types.float8_e4m3fn: | ||||||
|  |             if group_size not in [-1, 128] or act_order: | ||||||
|  |                 return | ||||||
|  |             marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size) | ||||||
|  |         elif group_size == 16: | ||||||
|  |             return | ||||||
|  |         elif has_zp: | ||||||
|  |             marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( | ||||||
|  |                 b, quant_type, group_size | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = ( | ||||||
|  |                 marlin_quantize(b, quant_type, group_size, act_order) | ||||||
|  |             ) | ||||||
|  |         return ( | ||||||
|  |             marlin_w_ref, | ||||||
|  |             marlin_q_w, | ||||||
|  |             marlin_s, | ||||||
|  |             marlin_s2, | ||||||
|  |             marlin_zp, | ||||||
|  |             marlin_g_idx, | ||||||
|  |             marlin_sort_indices, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def gen_marlin_24_params(): | ||||||
|  |         marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None | ||||||
|  |         if marlin_24_supported: | ||||||
|  |             (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( | ||||||
|  |                 marlin_24_quantize(b, quant_type, group_size) | ||||||
|  |             ) | ||||||
|  |         return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) | ||||||
|  |  | ||||||
|  |     def gen_repack_params(): | ||||||
|  |         q_w_gptq = None | ||||||
|  |         repack_sort_indices = None | ||||||
|  |         if repack_supported: | ||||||
|  |             (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( | ||||||
|  |                 b, quant_type, group_size, act_order | ||||||
|  |             ) | ||||||
|  |             q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) | ||||||
|  |  | ||||||
|  |             # For act_order, sort the "weights" and "g_idx" | ||||||
|  |             # so that group ids are increasing | ||||||
|  |             repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) | ||||||
|  |             if act_order: | ||||||
|  |                 (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) | ||||||
|  |         return q_w_gptq, repack_sort_indices | ||||||
|  |  | ||||||
|  |     def gen_allspark_params(): | ||||||
|  |         qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = ( | ||||||
|  |             CUBLAS_M_THRESHOLD | ||||||
|  |         ) = None | ||||||
|  |         nonlocal allspark_supported | ||||||
|  |         if allspark_supported: | ||||||
|  |             properties = torch.cuda.get_device_properties(b.device.index) | ||||||
|  |             sm_count = properties.multi_processor_count | ||||||
|  |             sm_version = properties.major * 10 + properties.minor | ||||||
|  |  | ||||||
|  |             supported_arch = sm_version >= 80 and sm_version < 90 | ||||||
|  |             allspark_supported = allspark_supported and supported_arch | ||||||
|  |             if supported_arch: | ||||||
|  |                 w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) | ||||||
|  |                 qw = qw.to(torch.uint8) | ||||||
|  |  | ||||||
|  |                 qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( | ||||||
|  |                     qw, s, zp, has_zp | ||||||
|  |                 ) | ||||||
|  |                 CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD | ||||||
|  |         return ( | ||||||
|  |             qw_reorder, | ||||||
|  |             s_reorder, | ||||||
|  |             zp_reorder, | ||||||
|  |             sm_count, | ||||||
|  |             sm_version, | ||||||
|  |             CUBLAS_M_THRESHOLD, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     ( | ||||||
|  |         marlin_w_ref, | ||||||
|  |         marlin_q_w, | ||||||
|  |         marlin_s, | ||||||
|  |         marlin_s2, | ||||||
|  |         marlin_zp, | ||||||
|  |         marlin_g_idx, | ||||||
|  |         marlin_sort_indices, | ||||||
|  |     ) = gen_marlin_params() | ||||||
|  |     marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = ( | ||||||
|  |         gen_marlin_24_params() | ||||||
|  |     ) | ||||||
|  |     q_w_gptq, repack_sort_indices = gen_repack_params() | ||||||
|  |     qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = ( | ||||||
|  |         gen_allspark_params() | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # Prepare | ||||||
|  |     marlin_workspace = MarlinWorkspace( | ||||||
|  |         size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL | ||||||
|  |     ) | ||||||
|  |     marlin_24_workspace = MarlinWorkspace( | ||||||
|  |         size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     globals = { |     globals = { | ||||||
|         # Gen params |         # Gen params | ||||||
| @ -140,15 +212,14 @@ def bench_run( | |||||||
|         "size_n": size_n, |         "size_n": size_n, | ||||||
|         "size_k": size_k, |         "size_k": size_k, | ||||||
|         "a": a, |         "a": a, | ||||||
|         "a_tmp": a_tmp, |  | ||||||
|         # Marlin params |         # Marlin params | ||||||
|         "marlin_w_ref": marlin_w_ref, |         "marlin_w_ref": marlin_w_ref, | ||||||
|         "marlin_q_w": marlin_q_w, |         "marlin_q_w": marlin_q_w, | ||||||
|         "marlin_s": marlin_s, |         "marlin_s": marlin_s, | ||||||
|  |         "marlin_s2": marlin_s2, | ||||||
|         "marlin_zp": marlin_zp, |         "marlin_zp": marlin_zp, | ||||||
|         "marlin_g_idx": marlin_g_idx, |         "marlin_g_idx": marlin_g_idx, | ||||||
|         "marlin_sort_indices": marlin_sort_indices, |         "marlin_sort_indices": marlin_sort_indices, | ||||||
|         "marlin_rand_perm": marlin_rand_perm, |  | ||||||
|         "marlin_workspace": marlin_workspace, |         "marlin_workspace": marlin_workspace, | ||||||
|         "is_k_full": is_k_full, |         "is_k_full": is_k_full, | ||||||
|         # Marlin_24 params |         # Marlin_24 params | ||||||
| @ -161,12 +232,12 @@ def bench_run( | |||||||
|         "q_w_gptq": q_w_gptq, |         "q_w_gptq": q_w_gptq, | ||||||
|         "repack_sort_indices": repack_sort_indices, |         "repack_sort_indices": repack_sort_indices, | ||||||
|         # AllSpark W8A16 params |         # AllSpark W8A16 params | ||||||
|         "qw_reorder": qw_reorder if as_supported_case else None, |         "qw_reorder": qw_reorder, | ||||||
|         "s_reorder": s_reorder if as_supported_case else None, |         "s_reorder": s_reorder, | ||||||
|         "zp_reorder": zp_reorder if as_supported_case else None, |         "zp_reorder": zp_reorder, | ||||||
|         "sm_count": sm_count if as_supported_case else None, |         "sm_count": sm_count, | ||||||
|         "sm_version": sm_version if as_supported_case else None, |         "sm_version": sm_version, | ||||||
|         "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, |         "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, | ||||||
|         # Kernels |         # Kernels | ||||||
|         "gptq_marlin_gemm": ops.gptq_marlin_gemm, |         "gptq_marlin_gemm": ops.gptq_marlin_gemm, | ||||||
|         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, |         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, | ||||||
| @ -177,7 +248,7 @@ def bench_run( | |||||||
|     min_run_time = 1 |     min_run_time = 1 | ||||||
|  |  | ||||||
|     # Warmup pytorch |     # Warmup pytorch | ||||||
|     for i in range(5): |     for _ in range(5): | ||||||
|         torch.matmul(a, marlin_w_ref) |         torch.matmul(a, marlin_w_ref) | ||||||
|  |  | ||||||
|     results.append( |     results.append( | ||||||
| @ -192,17 +263,17 @@ def bench_run( | |||||||
|  |  | ||||||
|     results.append( |     results.append( | ||||||
|         benchmark.Timer( |         benchmark.Timer( | ||||||
|             stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501 |             stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501 | ||||||
|             globals=globals, |             globals=globals, | ||||||
|             label=label, |             label=label, | ||||||
|             sub_label=sub_label, |             sub_label=sub_label, | ||||||
|             description="gptq_marlin_gemm_fp16", |             description="gptq_marlin_gemm", | ||||||
|         ).blocked_autorange(min_run_time=min_run_time) |         ).blocked_autorange(min_run_time=min_run_time) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     results.append( |     results.append( | ||||||
|         benchmark.Timer( |         benchmark.Timer( | ||||||
|             stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501 |             stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501 | ||||||
|             globals=globals, |             globals=globals, | ||||||
|             label=label, |             label=label, | ||||||
|             sub_label=sub_label, |             sub_label=sub_label, | ||||||
| @ -210,10 +281,7 @@ def bench_run( | |||||||
|         ).blocked_autorange(min_run_time=min_run_time) |         ).blocked_autorange(min_run_time=min_run_time) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     if ( |     if marlin_24_supported: | ||||||
|         quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES |  | ||||||
|         and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES |  | ||||||
|     ): |  | ||||||
|         results.append( |         results.append( | ||||||
|             benchmark.Timer( |             benchmark.Timer( | ||||||
|                 stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501 |                 stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501 | ||||||
| @ -224,17 +292,18 @@ def bench_run( | |||||||
|             ).blocked_autorange(min_run_time=min_run_time) |             ).blocked_autorange(min_run_time=min_run_time) | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     results.append( |     if repack_supported: | ||||||
|         benchmark.Timer( |         results.append( | ||||||
|             stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501 |             benchmark.Timer( | ||||||
|             globals=globals, |                 stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501 | ||||||
|             label=label, |                 globals=globals, | ||||||
|             sub_label=sub_label, |                 label=label, | ||||||
|             description="gptq_marlin_repack", |                 sub_label=sub_label, | ||||||
|         ).blocked_autorange(min_run_time=min_run_time) |                 description="gptq_marlin_repack", | ||||||
|     ) |             ).blocked_autorange(min_run_time=min_run_time) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     if as_supported_case: |     if allspark_supported: | ||||||
|         results.append( |         results.append( | ||||||
|             benchmark.Timer( |             benchmark.Timer( | ||||||
|                 stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501 |                 stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501 | ||||||
| @ -250,7 +319,6 @@ def main(args): | |||||||
|     print("Benchmarking models:") |     print("Benchmarking models:") | ||||||
|     for i, model in enumerate(args.models): |     for i, model in enumerate(args.models): | ||||||
|         print(f"[{i}]  {model}") |         print(f"[{i}]  {model}") | ||||||
|  |  | ||||||
|     results: list[benchmark.Measurement] = [] |     results: list[benchmark.Measurement] = [] | ||||||
|  |  | ||||||
|     for model in args.models: |     for model in args.models: | ||||||
| @ -278,14 +346,17 @@ def main(args): | |||||||
|                     ): |                     ): | ||||||
|                         continue |                         continue | ||||||
|  |  | ||||||
|                     for quant_type in query_marlin_supported_quant_types(False): |                     for quant_type in query_marlin_supported_quant_types(): | ||||||
|                         if ( |                         if ( | ||||||
|                             len(args.limit_num_bits) > 0 |                             len(args.limit_num_bits) > 0 | ||||||
|                             and quant_type.size_bits not in args.limit_num_bits |                             and quant_type.size_bits not in args.limit_num_bits | ||||||
|                         ): |                         ): | ||||||
|                             continue |                             continue | ||||||
|  |  | ||||||
|                         for group_size in MARLIN_SUPPORTED_GROUP_SIZES: |                         for group_size in ( | ||||||
|  |                             MARLIN_SUPPORTED_GROUP_SIZES | ||||||
|  |                             + FP4_MARLIN_SUPPORTED_GROUP_SIZES | ||||||
|  |                         ): | ||||||
|                             if ( |                             if ( | ||||||
|                                 len(args.limit_group_size) > 0 |                                 len(args.limit_group_size) > 0 | ||||||
|                                 and group_size not in args.limit_group_size |                                 and group_size not in args.limit_group_size | ||||||
|  | |||||||
| @ -86,6 +86,9 @@ def benchmark_config( | |||||||
|             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 |             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 | ||||||
|         ) |         ) | ||||||
|         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) |         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) | ||||||
|  |     if use_deep_gemm: | ||||||
|  |         # we use the default block shape for deepgemm | ||||||
|  |         block_quant_shape = [128, 128] | ||||||
|     if use_fp8_w8a8: |     if use_fp8_w8a8: | ||||||
|         if block_quant_shape: |         if block_quant_shape: | ||||||
|             block_n, block_k = block_quant_shape[0], block_quant_shape[1] |             block_n, block_k = block_quant_shape[0], block_quant_shape[1] | ||||||
| @ -620,7 +623,7 @@ def main(args: argparse.Namespace): | |||||||
|             4096, |             4096, | ||||||
|         ] |         ] | ||||||
|     else: |     else: | ||||||
|         batch_sizes = [args.batch_size] |         batch_sizes = args.batch_size | ||||||
|  |  | ||||||
|     use_deep_gemm = bool(args.use_deep_gemm) |     use_deep_gemm = bool(args.use_deep_gemm) | ||||||
|  |  | ||||||
| @ -728,7 +731,7 @@ if __name__ == "__main__": | |||||||
|     ) |     ) | ||||||
|     parser.add_argument("--use-deep-gemm", action="store_true") |     parser.add_argument("--use-deep-gemm", action="store_true") | ||||||
|     parser.add_argument("--seed", type=int, default=0) |     parser.add_argument("--seed", type=int, default=0) | ||||||
|     parser.add_argument("--batch-size", type=int, required=False) |     parser.add_argument("--batch-size", type=int, nargs="+", required=False) | ||||||
|     parser.add_argument("--tune", action="store_true") |     parser.add_argument("--tune", action="store_true") | ||||||
|     parser.add_argument("--trust-remote-code", action="store_true") |     parser.add_argument("--trust-remote-code", action="store_true") | ||||||
|     parser.add_argument("--model-prefix", type=str, required=False) |     parser.add_argument("--model-prefix", type=str, required=False) | ||||||
|  | |||||||
							
								
								
									
										159
									
								
								benchmarks/kernels/benchmark_moe_align_block_size.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								benchmarks/kernels/benchmark_moe_align_block_size.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,159 @@ | |||||||
|  | # SPDX-License-Identifier: Apache-2.0 | ||||||
|  | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|  | import argparse | ||||||
|  | import itertools | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from vllm import _custom_ops as ops | ||||||
|  | from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( | ||||||
|  |     moe_align_block_size_triton, | ||||||
|  | ) | ||||||
|  | from vllm.triton_utils import triton | ||||||
|  |  | ||||||
|  |  | ||||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    """Draw ``topk`` distinct random expert ids for each of ``num_tokens`` rows.

    Each row is the first ``topk`` entries of an independent random
    permutation of ``range(num_experts)``, created as int32 on the "cuda"
    device. The rows are stacked into a single tensor.
    """
    rows = []
    for _ in range(num_tokens):
        permutation = torch.randperm(num_experts, dtype=torch.int32, device="cuda")
        rows.append(permutation[:topk])
    return torch.stack(rows)
|  |  | ||||||
|  |  | ||||||
def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
    """
    Runs the Triton and vLLM moe_align_block_size implementations on the same
    random routing and prints whether their outputs agree.

    Only ``expert_ids`` and ``num_tokens_post_pad`` are compared; the sorted
    token ids are produced by both implementations but not checked here.
    """
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)
    sentinel = topk_ids.numel()

    # 1. allocate output buffers for both implementations
    # enough space (max_num_tokens_padded) is reserved for the sorted ids,
    # which are pre-filled with the sentinel value
    max_num_tokens_padded = sentinel + num_experts * (block_size - 1)
    max_num_blocks = max_num_tokens_padded // block_size

    sorted_ids_triton = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
    ).fill_(sentinel)
    expert_ids_triton = torch.zeros(
        (max_num_blocks,), dtype=torch.int32, device="cuda"
    )
    num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")

    sorted_ids_vllm = torch.empty_like(sorted_ids_triton).fill_(sentinel)
    expert_ids_vllm = torch.zeros_like(expert_ids_triton)
    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)

    # 2. run both implementations on the same input
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )
    ops.moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_vllm,
        expert_ids_vllm,
        num_tokens_post_pad_vllm,
    )
    print(f"✅ VLLM implementation works with {num_experts} experts!")

    # 3. compare the outputs and report
    expert_ids_match = torch.allclose(expert_ids_triton, expert_ids_vllm)
    outputs_match = expert_ids_match and torch.allclose(
        num_tokens_post_pad_triton, num_tokens_post_pad_vllm
    )
    if outputs_match:
        print("✅ Triton and VLLM implementations match.")
        return

    print("❌ Triton and VLLM implementations DO NOT match.")
    print("Triton expert_ids:", expert_ids_triton)
    print("VLLM expert_ids:", expert_ids_vllm)
    print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
    print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|  |  | ||||||
|  |  | ||||||
# Benchmark sweep: every (num_tokens, num_experts, topk) combination, with
# topk varying fastest and num_tokens slowest.
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
topk_range = [1, 2, 8]
configs = [
    (num_tokens, num_experts, topk)
    for num_tokens in num_tokens_range
    for num_experts in num_experts_range
    for topk in topk_range
]
|  |  | ||||||
|  |  | ||||||
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "triton"],
        line_names=["VLLM", "Triton"],
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one moe_align_block_size implementation for a single config.

    Args:
        num_tokens: Number of rows in the random top-k routing tensor.
        num_experts: Number of experts to route between.
        topk: Number of expert ids per row.
        provider: "vllm" for ``ops.moe_align_block_size`` or "triton" for
            ``moe_align_block_size_triton``.

    Returns:
        ``(median, low, high)`` timings, i.e. the 0.5, 0.2 and 0.8 quantiles
        reported by ``do_bench``, each multiplied by 1000.

    Raises:
        ValueError: If ``provider`` is not one of the known implementations.
    """
    block_size = 256
    topk_ids = get_topk_ids(num_tokens, num_experts, topk)

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

    if provider == "vllm":
        align_fn = ops.moe_align_block_size
    elif provider == "triton":
        align_fn = moe_align_block_size_triton
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    # Quantile order determines the unpacking order below:
    # 0.5 -> median, 0.2 -> low, 0.8 -> high.
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(
        # Every timed call receives fresh copies of the output buffers.
        lambda: align_fn(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids.clone(),
            expert_ids.clone(),
            num_tokens_post_pad.clone(),
        ),
        quantiles=quantiles,
    )

    # perf_report unpacks the result as (mean, min, max); since these are
    # plain timings the low quantile must come before the high one.
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
|  |  | ||||||
|  |  | ||||||
if __name__ == "__main__":
    # Command-line options only affect the correctness check; the benchmark
    # sweep itself is driven by the module-level configs.
    cli_options = {
        "--num_experts": dict(
            type=int,
            default=64,
            choices=[8, 16, 32, 64, 128, 256],
        ),
        "--topk": dict(
            type=int,
            default=8,
            choices=[2, 4, 8],
            help="Top-k value for correctness check.",
        ),
    }

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options.items():
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    print("Running correctness check...")
    check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
    benchmark.run(print_data=True, show_plots=True)
							
								
								
									
										240
									
								
								benchmarks/kernels/benchmark_trtllm_attention.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								benchmarks/kernels/benchmark_trtllm_attention.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,240 @@ | |||||||
|  | # SPDX-License-Identifier: Apache-2.0 | ||||||
|  | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||||
|  |  | ||||||
|  | import csv | ||||||
|  | import os | ||||||
|  | import random | ||||||
|  | from datetime import datetime | ||||||
|  |  | ||||||
|  | import flashinfer | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | ||||||
|  |  | ||||||
|  | # KV Cache Layout for TRT-LLM | ||||||
|  | # kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def to_float8(x, dtype=torch.float8_e4m3fn): | ||||||
|  |     finfo = torch.finfo(dtype) | ||||||
|  |     min_val, max_val = x.aminmax() | ||||||
|  |     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) | ||||||
|  |     scale = finfo.max / amax * 0.1 | ||||||
|  |     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) | ||||||
|  |     return x_scl_sat.to(dtype), scale.float().reciprocal() | ||||||
|  |  | ||||||
|  |  | ||||||
@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    """Time TRT-LLM batch decode against the FlashInfer paged-KV decode wrapper.

    Builds a random query, a random paged KV cache and random block tables,
    times both decode paths on them, prints one tab-separated result row and
    returns the same numbers as a dict.

    Args:
        num_seqs: Number of sequences in the decode batch.
        max_seq_len: Upper bound for the randomly drawn per-sequence KV length.
        page_size: Number of tokens per KV cache page.
        dtype: Query dtype; also the KV cache dtype unless fp8 is requested.
        kv_layout: Layout string handed to the FlashInfer wrapper.
        num_kv_heads: Number of KV heads; query heads are 8 times this.
        kv_cache_dtype: A value starting with "fp8" converts the KV cache
            with ``to_float8``; anything else keeps ``dtype``.
        head_dim: Size of each attention head.
        warmup: Untimed calls made before measuring each implementation.
        trials: Timed calls per implementation.

    Returns:
        A dict whose keys match the CSV columns used by
        ``write_results_to_csv``.
    """
    # NOTE: this changes the process-wide default device.
    torch.set_default_device("cuda")
    device = "cuda"
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    # Lengths come from the `random` module, which torch.manual_seed does not
    # seed, so they differ between runs.
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        # The inverse scale returned by to_float8 is discarded; k_scale and
        # v_scale stay at 1.0.
        kv_cache, _ = to_float8(kv_cache)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            num_qo_heads,
            num_kv_heads,
            sm_scale,
            block_tables,
            kv_lens_tensor,
            page_size,
            max_kv_len,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    def time_fn(fn):
        """Return (mean, std) of ``fn``'s runtime in ms over ``trials`` runs.

        Uses the ``warmup`` and ``trials`` arguments of benchmark_decode.
        """
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode
    trt_mean, trt_std = time_fn(trt_decode)

    # Flatten the per-sequence block tables into the indptr / indices /
    # last-page-length form that the FlashInfer wrapper's plan() takes.
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        # tolist() copies the row to Python ints in one step instead of
        # iterating the device tensor element by element.
        kv_indices.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)

    baseline_mean, baseline_std = time_fn(baseline_decode)

    # Relative speedup as a fraction of the baseline time (positive means TRT
    # is faster). Despite the name, the value is not multiplied by 100.
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }
|  |  | ||||||
|  |  | ||||||
def write_results_to_csv(results, filename=None):
    """Append benchmark result rows to a CSV file.

    Args:
        results: Iterable of dicts keyed by the column names listed below.
        filename: Destination path. When omitted, a timestamped
            ``flashinfer_trtllm_benchmark_<YYYYmmdd_HHMMSS>.csv`` name is used.

    The header row is written only when the file does not exist yet, so
    repeated calls with the same path keep appending rows.
    """
    columns = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    if filename is None:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{stamp}.csv"

    # Check before opening: append mode creates the file if it is missing.
    needs_header = not os.path.exists(filename)

    with open(filename, "a", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        if needs_header:
            writer.writeheader()
        writer.writerows(results)

    print(f"Results written to {filename}")
|  |  | ||||||
|  |  | ||||||
if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    column_header = (
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )

    # One sweep per KV cache dtype; the query dtype is bfloat16 in both.
    sweeps = [
        ("Running benchmark for kv_cache_dtype: bfloat16", "auto"),
        ("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8", "fp8"),
    ]

    for banner, kv_cache_dtype in sweeps:
        print(banner)
        print(column_header)
        for max_seq_len in max_seq_lens:
            for bs in num_seqs:
                result = benchmark_decode(
                    bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype=kv_cache_dtype
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
| @ -85,12 +85,6 @@ def benchmark_shape(m: int, | |||||||
|  |  | ||||||
|     # === DeepGEMM Implementation === |     # === DeepGEMM Implementation === | ||||||
|     def deepgemm_gemm(): |     def deepgemm_gemm(): | ||||||
|         # A quantization is inside the loop as it depends on activations |  | ||||||
|         # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) |  | ||||||
|         # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( |  | ||||||
|         #     A, block_size[1]) |  | ||||||
|         # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) |  | ||||||
|         # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) |  | ||||||
|         deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), |         deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), | ||||||
|                                        (B_deepgemm, B_scale_deepgemm), |                                        (B_deepgemm, B_scale_deepgemm), | ||||||
|                                        C_deepgemm) |                                        C_deepgemm) | ||||||
| @ -98,8 +92,6 @@ def benchmark_shape(m: int, | |||||||
|  |  | ||||||
|     # === vLLM Triton Implementation === |     # === vLLM Triton Implementation === | ||||||
|     def vllm_triton_gemm(): |     def vllm_triton_gemm(): | ||||||
|         # A quantization is inside the loop as it depends on activations |  | ||||||
|         # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) |  | ||||||
|         return w8a8_block_fp8_matmul(A_vllm, |         return w8a8_block_fp8_matmul(A_vllm, | ||||||
|                                      B_vllm, |                                      B_vllm, | ||||||
|                                      A_scale_vllm, |                                      A_scale_vllm, | ||||||
| @ -109,9 +101,6 @@ def benchmark_shape(m: int, | |||||||
|  |  | ||||||
|     # === vLLM CUTLASS Implementation === |     # === vLLM CUTLASS Implementation === | ||||||
|     def vllm_cutlass_gemm(): |     def vllm_cutlass_gemm(): | ||||||
|         # A quantization is inside the loop as it depends on activations |  | ||||||
|         # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( |  | ||||||
|         #     A, block_size[1], column_major_scales=True) |  | ||||||
|         return ops.cutlass_scaled_mm(A_vllm_cutlass, |         return ops.cutlass_scaled_mm(A_vllm_cutlass, | ||||||
|                                      B_vllm.T, |                                      B_vllm.T, | ||||||
|                                      scale_a=A_scale_vllm_cutlass, |                                      scale_a=A_scale_vllm_cutlass, | ||||||
|  | |||||||
| @ -12,9 +12,8 @@ endif() | |||||||
| # | # | ||||||
| # Define environment variables for special configurations | # Define environment variables for special configurations | ||||||
| # | # | ||||||
| if(DEFINED ENV{VLLM_CPU_AVX512BF16}) | set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) | ||||||
|     set(ENABLE_AVX512BF16 ON) | set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) | ||||||
| endif() |  | ||||||
|  |  | ||||||
| include_directories("${CMAKE_SOURCE_DIR}/csrc") | include_directories("${CMAKE_SOURCE_DIR}/csrc") | ||||||
|  |  | ||||||
| @ -96,12 +95,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) | |||||||
|         if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND |         if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND | ||||||
|             CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) |             CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) | ||||||
|             list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") |             list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") | ||||||
|  |             set(ENABLE_AVX512BF16 ON) | ||||||
|         else() |         else() | ||||||
|  |             set(ENABLE_AVX512BF16 OFF) | ||||||
|             message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") |             message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") | ||||||
|         endif() |         endif() | ||||||
|     else() |     else() | ||||||
|  |         set(ENABLE_AVX512BF16 OFF) | ||||||
|         message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") |         message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") | ||||||
|     endif() |     endif() | ||||||
|  |  | ||||||
|  |     find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) | ||||||
|  |     if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI) | ||||||
|  |         if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND | ||||||
|  |             CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) | ||||||
|  |             list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") | ||||||
|  |             set(ENABLE_AVX512VNNI ON) | ||||||
|  |         else() | ||||||
|  |             set(ENABLE_AVX512VNNI OFF) | ||||||
|  |             message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3") | ||||||
|  |         endif() | ||||||
|  |     else() | ||||||
|  |         set(ENABLE_AVX512VNNI OFF) | ||||||
|  |         message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") | ||||||
|  |     endif() | ||||||
|      |      | ||||||
| elseif (AVX2_FOUND) | elseif (AVX2_FOUND) | ||||||
|     list(APPEND CXX_COMPILE_FLAGS "-mavx2") |     list(APPEND CXX_COMPILE_FLAGS "-mavx2") | ||||||
| @ -148,17 +165,32 @@ else() | |||||||
| endif() | endif() | ||||||
|  |  | ||||||
| # | # | ||||||
| # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) | # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) | ||||||
| # | # Flag to enable ACL kernels for AARCH64 platforms | ||||||
| if (AVX512_FOUND AND NOT AVX512_DISABLED) | if ( VLLM_BUILD_ACL STREQUAL "ON") | ||||||
|  |     set(USE_ACL ON) | ||||||
|  | else() | ||||||
|  |     set(USE_ACL OFF) | ||||||
|  | endif() | ||||||
|  |  | ||||||
|  | if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) | ||||||
|     FetchContent_Declare( |     FetchContent_Declare( | ||||||
|         oneDNN |         oneDNN | ||||||
|         GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git |         GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git | ||||||
|         GIT_TAG  v3.7.1 |         GIT_TAG  v3.8.1 | ||||||
|         GIT_PROGRESS TRUE |         GIT_PROGRESS TRUE | ||||||
|         GIT_SHALLOW TRUE |         GIT_SHALLOW TRUE | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |     if(USE_ACL) | ||||||
|  |         find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/) | ||||||
|  |         if(NOT ARM_COMPUTE_LIBRARY) | ||||||
|  |             message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR") | ||||||
|  |         endif() | ||||||
|  |         set(ONEDNN_AARCH64_USE_ACL "ON") | ||||||
|  |         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") | ||||||
|  |         endif() | ||||||
|  |  | ||||||
|     set(ONEDNN_LIBRARY_TYPE "STATIC") |     set(ONEDNN_LIBRARY_TYPE "STATIC") | ||||||
|     set(ONEDNN_BUILD_DOC "OFF") |     set(ONEDNN_BUILD_DOC "OFF") | ||||||
|     set(ONEDNN_BUILD_EXAMPLES "OFF") |     set(ONEDNN_BUILD_EXAMPLES "OFF") | ||||||
| @ -231,11 +263,29 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) | |||||||
|         "csrc/cpu/quant.cpp" |         "csrc/cpu/quant.cpp" | ||||||
|         "csrc/cpu/shm.cpp" |         "csrc/cpu/shm.cpp" | ||||||
|         ${VLLM_EXT_SRC}) |         ${VLLM_EXT_SRC}) | ||||||
|  |     if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) | ||||||
|  |         set(VLLM_EXT_SRC | ||||||
|  |             "csrc/cpu/sgl-kernels/gemm.cpp" | ||||||
|  |             "csrc/cpu/sgl-kernels/gemm_int8.cpp" | ||||||
|  |             "csrc/cpu/sgl-kernels/gemm_fp8.cpp" | ||||||
|  |             "csrc/cpu/sgl-kernels/moe.cpp" | ||||||
|  |             "csrc/cpu/sgl-kernels/moe_int8.cpp" | ||||||
|  |             "csrc/cpu/sgl-kernels/moe_fp8.cpp" | ||||||
|  |             ${VLLM_EXT_SRC}) | ||||||
|  |         add_compile_definitions(-DCPU_CAPABILITY_AVX512) | ||||||
|  |     endif() | ||||||
| elseif(POWER10_FOUND) | elseif(POWER10_FOUND) | ||||||
|     set(VLLM_EXT_SRC |     set(VLLM_EXT_SRC | ||||||
|         "csrc/cpu/quant.cpp" |         "csrc/cpu/quant.cpp" | ||||||
|         ${VLLM_EXT_SRC}) |         ${VLLM_EXT_SRC}) | ||||||
| endif() | endif() | ||||||
|  | if (ASIMD_FOUND) | ||||||
|  |     set(VLLM_EXT_SRC | ||||||
|  |         "csrc/cpu/quant.cpp" | ||||||
|  |         ${VLLM_EXT_SRC}) | ||||||
|  | endif() | ||||||
|  |  | ||||||
|  | message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") | ||||||
|  |  | ||||||
| # | # | ||||||
| # Define extension targets | # Define extension targets | ||||||
|  | |||||||
| @ -38,7 +38,7 @@ else() | |||||||
|   FetchContent_Declare( |   FetchContent_Declare( | ||||||
|           vllm-flash-attn |           vllm-flash-attn | ||||||
|           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git |           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git | ||||||
|           GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491 |           GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4 | ||||||
|           GIT_PROGRESS TRUE |           GIT_PROGRESS TRUE | ||||||
|           # Don't share the vllm-flash-attn build between build types |           # Don't share the vllm-flash-attn build between build types | ||||||
|           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn |           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn | ||||||
|  | |||||||
| @ -122,6 +122,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) | |||||||
|       "-DENABLE_FP8" |       "-DENABLE_FP8" | ||||||
|       "-U__HIP_NO_HALF_CONVERSIONS__" |       "-U__HIP_NO_HALF_CONVERSIONS__" | ||||||
|       "-U__HIP_NO_HALF_OPERATORS__" |       "-U__HIP_NO_HALF_OPERATORS__" | ||||||
|  |       "-Werror=unused-variable" | ||||||
|       "-fno-gpu-rdc") |       "-fno-gpu-rdc") | ||||||
|  |  | ||||||
|   endif() |   endif() | ||||||
| @ -264,8 +265,8 @@ macro(set_gencode_flags_for_srcs) | |||||||
| endmacro() | endmacro() | ||||||
|  |  | ||||||
| # | # | ||||||
| # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form  | # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form | ||||||
| #  `<major>.<minor>[letter]` compute the "loose intersection" with the  | #  `<major>.<minor>[letter]` compute the "loose intersection" with the | ||||||
| #  `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in | #  `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in | ||||||
| #  `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there | #  `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there | ||||||
| #  is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the | #  is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the | ||||||
| @ -277,7 +278,7 @@ endmacro() | |||||||
| #  in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. | #  in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. | ||||||
| # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is | # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is | ||||||
| #  in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add | #  in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add | ||||||
| #  x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).  | #  x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). | ||||||
| # The result is stored in `OUT_CUDA_ARCHS`. | # The result is stored in `OUT_CUDA_ARCHS`. | ||||||
| # | # | ||||||
| # Example: | # Example: | ||||||
| @ -312,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR | |||||||
|   # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should |   # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should | ||||||
|   # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS |   # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS | ||||||
|   set(_CUDA_ARCHS) |   set(_CUDA_ARCHS) | ||||||
|   if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) |   foreach(_arch ${_SRC_CUDA_ARCHS}) | ||||||
|     list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") |     if(_arch MATCHES "\\a$") | ||||||
|     if ("9.0" IN_LIST TGT_CUDA_ARCHS) |       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") | ||||||
|       list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") |       string(REPLACE "a" "" _base "${_arch}") | ||||||
|       set(_CUDA_ARCHS "9.0a") |       if ("${_base}" IN_LIST TGT_CUDA_ARCHS) | ||||||
|  |         list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") | ||||||
|  |         list(APPEND _CUDA_ARCHS "${_arch}") | ||||||
|  |       endif() | ||||||
|     endif() |     endif() | ||||||
|   endif() |   endforeach() | ||||||
|  |  | ||||||
|   if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) |  | ||||||
|     list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") |  | ||||||
|     if ("10.0" IN_LIST TGT_CUDA_ARCHS) |  | ||||||
|       list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") |  | ||||||
|       set(_CUDA_ARCHS "10.0a") |  | ||||||
|     endif() |  | ||||||
|   endif() |  | ||||||
|  |  | ||||||
|   list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) |   list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) | ||||||
|  |  | ||||||
| @ -358,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR | |||||||
|   endforeach() |   endforeach() | ||||||
|  |  | ||||||
|   list(REMOVE_DUPLICATES _CUDA_ARCHS) |   list(REMOVE_DUPLICATES _CUDA_ARCHS) | ||||||
|    |  | ||||||
|   # reapply +PTX suffix to architectures that requested PTX |   # reapply +PTX suffix to architectures that requested PTX | ||||||
|   set(_FINAL_ARCHS) |   set(_FINAL_ARCHS) | ||||||
|   foreach(_arch ${_CUDA_ARCHS}) |   foreach(_arch ${_CUDA_ARCHS}) | ||||||
| @ -369,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR | |||||||
|     endif() |     endif() | ||||||
|   endforeach() |   endforeach() | ||||||
|   set(_CUDA_ARCHS ${_FINAL_ARCHS}) |   set(_CUDA_ARCHS ${_FINAL_ARCHS}) | ||||||
|    |  | ||||||
|   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) |   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) | ||||||
| endfunction() | endfunction() | ||||||
|  |  | ||||||
|  | |||||||
| @ -24,6 +24,7 @@ | |||||||
|  |  | ||||||
| #include "attention_dtypes.h" | #include "attention_dtypes.h" | ||||||
| #include "attention_utils.cuh" | #include "attention_utils.cuh" | ||||||
|  | #include "cuda_compat.h" | ||||||
|  |  | ||||||
| #ifdef USE_ROCM | #ifdef USE_ROCM | ||||||
|   #include <hip/hip_bf16.h> |   #include <hip/hip_bf16.h> | ||||||
| @ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16; | |||||||
|   #include "../quantization/fp8/nvidia/quant_utils.cuh" |   #include "../quantization/fp8/nvidia/quant_utils.cuh" | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #ifndef USE_ROCM |  | ||||||
|   #define WARP_SIZE 32 |  | ||||||
| #else |  | ||||||
|   #define WARP_SIZE warpSize |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||||
| #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) | #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) | ||||||
| @ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel( | |||||||
|  |  | ||||||
| }  // namespace vllm | }  // namespace vllm | ||||||
|  |  | ||||||
| #undef WARP_SIZE |  | ||||||
| #undef MAX | #undef MAX | ||||||
| #undef MIN | #undef MIN | ||||||
| #undef DIVIDE_ROUND_UP | #undef DIVIDE_ROUND_UP | ||||||
|  | |||||||
| @ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, | |||||||
|               "page_table must be a 32-bit integer tensor"); |               "page_table must be a 32-bit integer tensor"); | ||||||
|  |  | ||||||
|   auto in_dtype = q_nope.dtype(); |   auto in_dtype = q_nope.dtype(); | ||||||
|   at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; |   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); | ||||||
|   const cudaStream_t stream = |   const cudaStream_t stream = | ||||||
|       at::cuda::getCurrentCUDAStream(q_nope.get_device()); |       at::cuda::getCurrentCUDAStream(q_nope.get_device()); | ||||||
|   if (in_dtype == at::ScalarType::Half) { |   if (in_dtype == at::ScalarType::Half) { | ||||||
|  | |||||||
							
								
								
									
										372
									
								
								csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										372
									
								
								csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,372 @@ | |||||||
|  | /*************************************************************************************************** | ||||||
|  |  * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||
|  |  * SPDX-License-Identifier: BSD-3-Clause | ||||||
|  |  * | ||||||
|  |  * Redistribution and use in source and binary forms, with or without | ||||||
|  |  * modification, are permitted provided that the following conditions are met: | ||||||
|  |  * | ||||||
|  |  * 1. Redistributions of source code must retain the above copyright notice, | ||||||
|  |  *this list of conditions and the following disclaimer. | ||||||
|  |  * | ||||||
|  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||||||
|  |  * this list of conditions and the following disclaimer in the documentation | ||||||
|  |  * and/or other materials provided with the distribution. | ||||||
|  |  * | ||||||
|  |  * 3. Neither the name of the copyright holder nor the names of its | ||||||
|  |  * contributors may be used to endorse or promote products derived from | ||||||
|  |  * this software without specific prior written permission. | ||||||
|  |  * | ||||||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||
|  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||
|  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||||
|  |  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||||||
|  |  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||||
|  |  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||||
|  |  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||||
|  |  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||||
|  |  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||||
|  |  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||||
|  |  *POSSIBILITY OF SUCH DAMAGE. | ||||||
|  |  * | ||||||
|  |  **************************************************************************************************/ | ||||||
|  | /* | ||||||
|  |  * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 | ||||||
|  |  * by Alcanderian JieXin Liang | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | /*! | ||||||
|  |  \file | ||||||
|  |  \brief A universal device layer for cutlass 3.x-style kernels. | ||||||
|  | */ | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | // common | ||||||
|  | #include "cutlass/cutlass.h" | ||||||
|  | #include "cutlass/device_kernel.h" | ||||||
|  |  | ||||||
|  | #if !defined(__CUDACC_RTC__) | ||||||
|  | #include "cutlass/cluster_launch.hpp" | ||||||
|  | #include "cutlass/trace.h" | ||||||
|  | #endif // !defined(__CUDACC_RTC__) | ||||||
|  |  | ||||||
|  | #include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp" | ||||||
|  | #include "../kernel/sm100_fmha_mla_reduction.hpp" | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | namespace cutlass::fmha::device { | ||||||
|  |  | ||||||
|  | using namespace cute; | ||||||
|  | using namespace cutlass::fmha::kernel; | ||||||
|  |  | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  | ////////////////////////////// CUTLASS 3.x API ///////////////////////////////// | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | template< | ||||||
|  |     class Kernel_ | ||||||
|  | > | ||||||
|  | class MLA { | ||||||
|  | public: | ||||||
|  |  | ||||||
|  |   using Kernel = Kernel_; | ||||||
|  |  | ||||||
|  |   using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< | ||||||
|  |       typename Kernel::ElementOut, | ||||||
|  |       typename Kernel::ElementAcc, | ||||||
|  |       typename Kernel::ElementAcc, | ||||||
|  |       Kernel::TileShapeH::value, | ||||||
|  |       Kernel::TileShapeL::value, | ||||||
|  |       256 /*Max split*/ | ||||||
|  |   >; | ||||||
|  |  | ||||||
|  |   /// Argument structure: User API | ||||||
|  |   using KernelArguments = typename Kernel::Arguments; | ||||||
|  |   using ReductionArguments = typename ReductionKernel::Arguments; | ||||||
|  |  | ||||||
|  |   using Arguments = KernelArguments; | ||||||
|  |  | ||||||
|  |   /// Argument structure: Kernel API | ||||||
|  |   using KernelParams = typename Kernel::Params; | ||||||
|  |   using ReductionParams = typename ReductionKernel::Params; | ||||||
|  |   struct Params { | ||||||
|  |     KernelParams fmha_params; | ||||||
|  |     ReductionParams reduction_params; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  | private: | ||||||
|  |  | ||||||
|  |   /// Kernel API parameters object | ||||||
|  |   Params params_; | ||||||
|  |  | ||||||
|  |   bool is_initialized(bool set = false) { | ||||||
|  |     static bool initialized = false; | ||||||
|  |     if (set) initialized = true; | ||||||
|  |     return initialized; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   static ReductionArguments to_reduction_args(Arguments const& args) { | ||||||
|  |     auto [H, K, D, B] = args.problem_shape; | ||||||
|  |     return ReductionArguments{ | ||||||
|  |       nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, | ||||||
|  |       args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, | ||||||
|  |       args.ptr_split_kv, Kernel::TileShapeS::value | ||||||
|  |     }; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | public: | ||||||
|  |  | ||||||
|  |   /// Access the Params structure | ||||||
|  |   Params const& params() const { | ||||||
|  |     return params_; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   static void set_split_kv (KernelArguments& args) { | ||||||
|  |     // printf("set_split_kv start"); | ||||||
|  |     if (args.split_kv >= 1) return; | ||||||
|  |     auto [H, K, D, B] = args.problem_shape; | ||||||
|  |     // std::cout << H << " " << K << " " << D << " " << B << "\n";       | ||||||
|  |     int sm_count = args.hw_info.sm_count; | ||||||
|  |     // printf("    sm_count = %d\n", sm_count); | ||||||
|  |     int max_splits = ceil_div(K, 128); | ||||||
|  |     max_splits = min(16, max_splits); | ||||||
|  |     // printf("    max_splits = %d\n", max_splits); | ||||||
|  |     int sms_per_batch = max(1, sm_count / B); | ||||||
|  |     // printf("    sms_per_batch = %d\n", sms_per_batch); | ||||||
|  |     int split_heur = min(max_splits, sms_per_batch); | ||||||
|  |     int waves = ceil_div(B * split_heur, sm_count); | ||||||
|  |     int k_waves = ceil_div(max_splits, split_heur); | ||||||
|  |     int split_wave_aware = ceil_div(max_splits, k_waves); | ||||||
|  |     args.split_kv = split_wave_aware; | ||||||
|  |     // printf("    args.split_kv = %d\n", args.split_kv); | ||||||
|  |  | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Determines whether the MLA kernel and its reduction kernel can execute the given problem. | ||||||
|  |   static Status | ||||||
|  |   can_implement(Arguments const& args) { | ||||||
|  |     if (! Kernel::can_implement(args)) { | ||||||
|  |       return Status::kInvalid; | ||||||
|  |     } | ||||||
|  |     if (! ReductionKernel::can_implement(to_reduction_args(args))) { | ||||||
|  |       return Status::kInvalid; | ||||||
|  |     } | ||||||
|  |     return Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Gets the workspace size | ||||||
|  |   static size_t | ||||||
|  |   get_workspace_size(Arguments const& args) { | ||||||
|  |     size_t workspace_bytes = 0; | ||||||
|  |     workspace_bytes += Kernel::get_workspace_size(args); | ||||||
|  |     workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); | ||||||
|  |     return workspace_bytes; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Computes the maximum number of active blocks per multiprocessor | ||||||
|  |   static int maximum_active_blocks(int /* smem_capacity */ = -1) { | ||||||
|  |     CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); | ||||||
|  |     int max_active_blocks = -1; | ||||||
|  |     int smem_size = Kernel::SharedStorageSize; | ||||||
|  |  | ||||||
|  |     // first, account for dynamic smem capacity if needed | ||||||
|  |     cudaError_t result; | ||||||
|  |     if (smem_size >= (48 << 10)) { | ||||||
|  |       CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size); | ||||||
|  |       result = cudaFuncSetAttribute( | ||||||
|  |           device_kernel<Kernel>, | ||||||
|  |           cudaFuncAttributeMaxDynamicSharedMemorySize, | ||||||
|  |           smem_size); | ||||||
|  |       if (cudaSuccess != result) { | ||||||
|  |         result = cudaGetLastError(); // to clear the error bit | ||||||
|  |         CUTLASS_TRACE_HOST( | ||||||
|  |           "  cudaFuncSetAttribute() returned error: " | ||||||
|  |           << cudaGetErrorString(result)); | ||||||
|  |         return -1; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // query occupancy after setting smem size | ||||||
|  |     result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( | ||||||
|  |         &max_active_blocks, | ||||||
|  |         device_kernel<Kernel>, | ||||||
|  |         Kernel::MaxThreadsPerBlock, | ||||||
|  |         smem_size); | ||||||
|  |  | ||||||
|  |     if (cudaSuccess != result) { | ||||||
|  |       result = cudaGetLastError(); // to clear the error bit | ||||||
|  |       CUTLASS_TRACE_HOST( | ||||||
|  |         "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " | ||||||
|  |         << cudaGetErrorString(result)); | ||||||
|  |       return -1; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks); | ||||||
|  |     return max_active_blocks; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Initializes MLA state (kernel and reduction params) from arguments. | ||||||
|  |   Status | ||||||
|  |   initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { | ||||||
|  |     CUTLASS_TRACE_HOST("MLA::initialize() - workspace " | ||||||
|  |       << workspace << ", stream: " << (stream ? "non-null" : "null")); | ||||||
|  |  | ||||||
|  |     // Initialize the workspace | ||||||
|  |     Status status = Kernel::initialize_workspace(args, workspace, stream); | ||||||
|  |     if (status != Status::kSuccess) { | ||||||
|  |       return status; | ||||||
|  |     } | ||||||
|  |     status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); | ||||||
|  |     if (status != Status::kSuccess) { | ||||||
|  |       return status; | ||||||
|  |     } | ||||||
|  |     KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); | ||||||
|  |  | ||||||
|  |     ReductionArguments reduction_args = to_reduction_args(args); | ||||||
|  |     if (reduction_args.split_kv > 1) { | ||||||
|  |       reduction_args.ptr_oaccum   = kernel_params.epilogue.ptr_o_acc; | ||||||
|  |       reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; | ||||||
|  |     } | ||||||
|  |     ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); | ||||||
|  |     // Initialize the Params structure | ||||||
|  |     params_ = Params {kernel_params, reduction_params}; | ||||||
|  |  | ||||||
|  |     if (is_initialized()) return Status::kSuccess; | ||||||
|  |  | ||||||
|  |     // account for dynamic smem capacity if needed | ||||||
|  |     // no dynamic smem is needed for reduction kernel | ||||||
|  |     int smem_size = Kernel::SharedStorageSize; | ||||||
|  |     if (smem_size >= (48 << 10)) { | ||||||
|  |       CUTLASS_TRACE_HOST("  Setting smem size to " << smem_size); | ||||||
|  |       cudaError_t result = cudaFuncSetAttribute( | ||||||
|  |           device_kernel<Kernel>, | ||||||
|  |           cudaFuncAttributeMaxDynamicSharedMemorySize, | ||||||
|  |           smem_size); | ||||||
|  |       if (cudaSuccess != result) { | ||||||
|  |         result = cudaGetLastError(); // to clear the error bit | ||||||
|  |         CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); | ||||||
|  |         return Status::kErrorInternal; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     is_initialized(true); | ||||||
|  |  | ||||||
|  |     return Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Rebuilds the stored launch parameters (params_) from `args`. | ||||||
|  |   /// The update API is preserved from 3.0, but it is a full recomputation rather than a lightweight patch. | ||||||
|  |   /// Returns Status::kErrorWorkspaceNull when a workspace is required but `workspace` is null, | ||||||
|  |   /// and Status::kSuccess otherwise. | ||||||
|  |   Status | ||||||
|  |   update(Arguments const& args, void* workspace = nullptr) { | ||||||
|  |     CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); | ||||||
|  |  | ||||||
|  |     // A non-zero workspace requirement cannot be satisfied without a buffer. | ||||||
|  |     size_t const required_bytes = get_workspace_size(args); | ||||||
|  |     bool const workspace_missing = (workspace == nullptr); | ||||||
|  |     if (required_bytes > 0 && workspace_missing) { | ||||||
|  |       return Status::kErrorWorkspaceNull; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto kernel_params = Kernel::to_underlying_arguments(args, workspace); | ||||||
|  |  | ||||||
|  |     // With more than one KV split, the reduction consumes the accumulators held by the attention epilogue params. | ||||||
|  |     ReductionArguments red_args = to_reduction_args(args); | ||||||
|  |     bool const uses_split_kv = red_args.split_kv > 1; | ||||||
|  |     if (uses_split_kv) { | ||||||
|  |       red_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; | ||||||
|  |       red_args.ptr_oaccum   = kernel_params.epilogue.ptr_o_acc; | ||||||
|  |     } | ||||||
|  |     ReductionParams red_params = ReductionKernel::to_underlying_arguments(red_args, workspace); | ||||||
|  |  | ||||||
|  |     // Store both parameter sets for later launches. | ||||||
|  |     params_ = Params {kernel_params, red_params}; | ||||||
|  |     return Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Primary run() entry point API that is static allowing users to create and manage their own params. | ||||||
|  |   /// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments(). | ||||||
|  |   /// | ||||||
|  |   /// Launches the attention kernel on `stream` and, when params.reduction_params.split_kv > 1, | ||||||
|  |   /// launches the split-KV reduction kernel on the same stream afterwards. | ||||||
|  |   /// Returns Status::kErrorInternal if either launch reports an error, Status::kSuccess otherwise. | ||||||
|  |   static Status | ||||||
|  |   run(Params& params, cudaStream_t stream = nullptr) { | ||||||
|  |     CUTLASS_TRACE_HOST("MLA::run()"); | ||||||
|  |     dim3 const block = Kernel::get_block_shape(); | ||||||
|  |     dim3 const grid = Kernel::get_grid_shape(params.fmha_params); | ||||||
|  |  | ||||||
|  |     // configure smem size and carveout | ||||||
|  |     int smem_size = Kernel::SharedStorageSize; | ||||||
|  |  | ||||||
|  |     Status launch_result; | ||||||
|  |     // Use extended launch API only for mainloops that use it | ||||||
|  |     if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { | ||||||
|  |       dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), | ||||||
|  |                    cute::size<1>(typename Kernel::ClusterShape{}), | ||||||
|  |                    cute::size<2>(typename Kernel::ClusterShape{})); | ||||||
|  |       void const* kernel = (void const*) device_kernel<Kernel>; | ||||||
|  |       // The launcher takes an array of pointers to the kernel arguments. | ||||||
|  |       void* kernel_params[] = {&params.fmha_params}; | ||||||
|  |       launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |       launch_result = Status::kSuccess; | ||||||
|  |       device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     cudaError_t result = cudaGetLastError(); | ||||||
|  |     if (cudaSuccess != result or Status::kSuccess != launch_result) { | ||||||
|  |       CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << cudaGetErrorString(result)); | ||||||
|  |       return Status::kErrorInternal; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // A single split leaves nothing to reduce. | ||||||
|  |     if (params.reduction_params.split_kv <= 1) { | ||||||
|  |       return Status::kSuccess; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // launch reduction kernel | ||||||
|  |     dim3 const reduction_block = ReductionKernel::get_block_shape(); | ||||||
|  |     dim3 const reduction_grid  = ReductionKernel::get_grid_shape(params.reduction_params); | ||||||
|  |     device_kernel<ReductionKernel><<<reduction_grid, reduction_block, 0, stream>>>(params.reduction_params); | ||||||
|  |     cudaError_t reduction_result = cudaGetLastError(); | ||||||
|  |     if (cudaSuccess != reduction_result) { | ||||||
|  |       CUTLASS_TRACE_HOST("  Kernel launch failed. Reason: " << cudaGetErrorString(reduction_result)); | ||||||
|  |       return Status::kErrorInternal; | ||||||
|  |     } | ||||||
|  |     return Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // | ||||||
|  |   // Non-static launch overloads that first create and set the internal params struct of this kernel handle. | ||||||
|  |   // | ||||||
|  |  | ||||||
|  |   /// Builds the internal params from `args` via initialize() and launches only if that succeeded; | ||||||
|  |   /// a failing initialization status is returned unchanged and nothing is launched. | ||||||
|  |   Status | ||||||
|  |   run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { | ||||||
|  |     Status const init_status = initialize(args, workspace, stream); | ||||||
|  |     if (Status::kSuccess != init_status) { | ||||||
|  |       return init_status; | ||||||
|  |     } | ||||||
|  |     return run(params_, stream); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Function-call form of run(args, workspace, stream): forwards all three arguments unchanged | ||||||
|  |   /// and returns the resulting Status. | ||||||
|  |   Status | ||||||
|  |   operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { | ||||||
|  |     return run(args, workspace, stream); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Overload that allows a user to re-launch the same kernel without updating internal params struct. | ||||||
|  |   /// Launches on `stream` with the params_ currently stored in this handle. | ||||||
|  |   Status | ||||||
|  |   run(cudaStream_t stream = nullptr) { | ||||||
|  |     return run(params_, stream); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /// Function-call form of the re-launch overload: launches on `stream` with the stored params_ | ||||||
|  |   /// without updating them. | ||||||
|  |   Status | ||||||
|  |   operator()(cudaStream_t stream = nullptr) { | ||||||
|  |     return run(params_, stream); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | } // namespace cutlass::fmha::device | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
| @ -0,0 +1,203 @@ | |||||||
|  | /*************************************************************************************************** | ||||||
|  |  * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights | ||||||
|  |  *reserved. SPDX-License-Identifier: BSD-3-Clause | ||||||
|  |  * | ||||||
|  |  * Redistribution and use in source and binary forms, with or without | ||||||
|  |  * modification, are permitted provided that the following conditions are met: | ||||||
|  |  * | ||||||
|  |  * 1. Redistributions of source code must retain the above copyright notice, | ||||||
|  |  *this list of conditions and the following disclaimer. | ||||||
|  |  * | ||||||
|  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||||||
|  |  * this list of conditions and the following disclaimer in the documentation | ||||||
|  |  * and/or other materials provided with the distribution. | ||||||
|  |  * | ||||||
|  |  * 3. Neither the name of the copyright holder nor the names of its | ||||||
|  |  * contributors may be used to endorse or promote products derived from | ||||||
|  |  * this software without specific prior written permission. | ||||||
|  |  * | ||||||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||
|  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||
|  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||||
|  |  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||||||
|  |  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||||
|  |  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||||
|  |  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||||
|  |  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||||
|  |  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||||
|  |  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||||
|  |  *POSSIBILITY OF SUCH DAMAGE. | ||||||
|  |  * | ||||||
|  |  **************************************************************************************************/ | ||||||
|  | /* | ||||||
|  |  * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 | ||||||
|  |  * by Alcanderian JieXin Liang | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "cutlass/cutlass.h" | ||||||
|  | #include "cutlass/arch/arch.h" | ||||||
|  | #include "cute/tensor.hpp" | ||||||
|  |  | ||||||
|  | namespace cutlass::fmha::kernel { | ||||||
|  |  | ||||||
|  | using namespace cute; | ||||||
|  | // Split-KV reduction kernel: combines the per-split partial outputs (ptr_oaccum) using weights | ||||||
|  | // derived from the per-split log-sum-exp values (ptr_lseaccum), writes the result to ptr_o and, | ||||||
|  | // when ptr_lse is non-null, the combined log-sum-exp to ptr_lse. | ||||||
|  | // The grid is (kNumHeads, 1, num_batches) with MaxThreadsPerBlock threads per block. | ||||||
|  | template< | ||||||
|  |     class ElementOut, | ||||||
|  |     class ElementAcc, | ||||||
|  |     class ElementScale, | ||||||
|  |     size_t kNumHeads, | ||||||
|  |     size_t kHeadDimLatent, | ||||||
|  |     int kMaxSplits | ||||||
|  | > | ||||||
|  | struct Sm100FmhaMlaReductionKernel { | ||||||
|  |  | ||||||
|  |   // No dynamic shared memory is requested; the kernel body declares its own static __shared__ array. | ||||||
|  |   static const int SharedStorageSize = 0; | ||||||
|  |   static const int MaxThreadsPerBlock = 128; | ||||||
|  |   static const int MinBlocksPerMultiprocessor = 1; | ||||||
|  |  | ||||||
|  |   using ArchTag = cutlass::arch::Sm100; | ||||||
|  |  | ||||||
|  |   // Each thread handles kHeadDimLatent / MaxThreadsPerBlock output elements, so the division must be exact. | ||||||
|  |   static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); | ||||||
|  |   struct Arguments { | ||||||
|  |     ElementAcc* ptr_oaccum = nullptr;    // per-split partial outputs (read) | ||||||
|  |     ElementOut* ptr_o = nullptr;         // final output (written) | ||||||
|  |     ElementAcc* ptr_lseaccum = nullptr;  // per-split log-sum-exp values (read) | ||||||
|  |     ElementAcc* ptr_lse = nullptr;       // combined log-sum-exp (written only when non-null) | ||||||
|  |     ElementScale scale = 1.f;            // not read by operator() in this struct | ||||||
|  |     int num_batches = 0;                 // grid z extent | ||||||
|  |     int split_kv = -1;                   // split count used for addressing; also the default per-batch split count | ||||||
|  |     int dim_k = -1;                      // sequence length used when ptr_seq is null | ||||||
|  |     int* ptr_seq = nullptr;              // optional per-batch sequence lengths, indexed by blockIdx.z | ||||||
|  |     int* ptr_split_kv = nullptr;         // optional per-batch split counts, indexed by blockIdx.z | ||||||
|  |     int tile_shape_s = 128;              // tile length used to count K tiles | ||||||
|  |   }; | ||||||
|  |   using Params = Arguments; | ||||||
|  |  | ||||||
|  |   // Params is the same type as Arguments; this is a field-by-field copy and `workspace` is unused. | ||||||
|  |   static Params to_underlying_arguments(Arguments const& args, void* workspace) { | ||||||
|  |     return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, | ||||||
|  | 	    args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, | ||||||
|  | 	    args.ptr_split_kv, args.tile_shape_s}; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // The reduction itself needs no workspace. | ||||||
|  |   static size_t get_workspace_size(Arguments const& /*args*/) { | ||||||
|  |     return 0; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   static Status initialize_workspace( | ||||||
|  |       Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { | ||||||
|  |     return Status::kSuccess; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // One block per (head, batch) pair: x indexes heads, z indexes batches. | ||||||
|  |   static dim3 get_grid_shape(Params const& params) { | ||||||
|  |     return dim3(kNumHeads, 1, params.num_batches); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   static dim3 get_block_shape() { | ||||||
|  |     return dim3(MaxThreadsPerBlock, 1, 1); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Only rejects non-positive batch and split counts. | ||||||
|  |   // NOTE(review): split_kv is not compared against kMaxSplits here, although operator() sizes its | ||||||
|  |   // shared array with kMaxSplits -- confirm the caller bounds the split count. | ||||||
|  |   static bool can_implement(Arguments const& args) { | ||||||
|  |     if (args.num_batches <= 0) return false; | ||||||
|  |     if (args.split_kv <= 0) return false; | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Kernel body. `smem_raw` is not used; the only shared memory is the static sLseScale array below. | ||||||
|  |   CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { | ||||||
|  |     // With a single split there is nothing to combine. | ||||||
|  |     if (params.split_kv <= 1) return; | ||||||
|  |     // (head, 0, batch) | ||||||
|  |     auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); | ||||||
|  |  | ||||||
|  |     // Per-split weights, written by warp 0 and read by every thread after the barrier. | ||||||
|  |     __shared__ ElementAcc sLseScale[kMaxSplits]; | ||||||
|  |     // Element of this (head, batch) for split 0; consecutive splits are kNumHeads elements apart (stride below). | ||||||
|  |     const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); | ||||||
|  |     const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); | ||||||
|  |  | ||||||
|  |     Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), | ||||||
|  |                                    make_shape(params.split_kv), Stride<Int<kNumHeads>>{}); | ||||||
|  |  | ||||||
|  |     Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), | ||||||
|  |                               Shape<_1>{}, Stride<_1>{}); | ||||||
|  |  | ||||||
|  |     // Per-batch overrides take precedence over the scalar values when their pointers are set. | ||||||
|  |     auto dim_k = params.ptr_seq == nullptr ?  params.dim_k : params.ptr_seq[get<2>(blk_coord)]; | ||||||
|  |     auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; | ||||||
|  |     // Recompute the split count from the tile count: tiles per split is rounded up, so fewer | ||||||
|  |     // splits than requested may be needed to cover all tiles. | ||||||
|  |     auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); | ||||||
|  |     auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); | ||||||
|  |     local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); | ||||||
|  |  | ||||||
|  |     // Only warp 0 computes the weights; threadIdx.x is therefore the lane index inside this branch. | ||||||
|  |     int warp_idx = cutlass::canonical_warp_idx_sync(); | ||||||
|  |     if (warp_idx == 0) { | ||||||
|  |       // Each lane covers splits lane, lane + 32, ... | ||||||
|  |       constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); | ||||||
|  |  | ||||||
|  |       ElementAcc local_lse[kNLsePerThread]; | ||||||
|  |  | ||||||
|  |       // Load this lane's LSE values; slots past the effective split count are filled with -inf. | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int i = 0; i < kNLsePerThread; ++i) { | ||||||
|  |         const int split = i * 32 + threadIdx.x; | ||||||
|  |         local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity(); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Maximum over this lane's values, then over all 32 lanes via XOR shuffles. | ||||||
|  |       ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity(); | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int i = 0; i < kNLsePerThread; ++i) { | ||||||
|  |         lse_max = max(lse_max, local_lse[i]); | ||||||
|  |       } | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int offset = 16; offset >= 1; offset /= 2) { | ||||||
|  |         lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); | ||||||
|  |       } | ||||||
|  |       lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max;  // In case all local LSEs are -inf | ||||||
|  |       // Take lane 0's value as the one used by every lane. | ||||||
|  |       lse_max = __shfl_sync(0xffffffff, lse_max, 0); | ||||||
|  |  | ||||||
|  |       // Sum of exp(lse - max) over this lane's values, then over all lanes. | ||||||
|  |       ElementAcc sum_lse = 0; | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int i = 0; i < kNLsePerThread; ++i) { | ||||||
|  |         sum_lse = sum_lse + expf(local_lse[i] - lse_max); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int offset = 16; offset >= 1; offset /= 2) { | ||||||
|  |         sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); | ||||||
|  |  | ||||||
|  |       // Combined LSE is log(sum) + max; a zero or NaN sum (sum != sum) yields +infinity instead. | ||||||
|  |       ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max; | ||||||
|  |       if (threadIdx.x == 0 and params.ptr_lse != nullptr) { | ||||||
|  |         gLSE(0) = global_lse; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // Weight of each split: exp(lse_split - combined_lse). | ||||||
|  |       // NOTE(review): the index is bounded by local_split_kv, not kMaxSplits (the array size) -- | ||||||
|  |       // confirm that local_split_kv <= kMaxSplits is guaranteed by the caller. | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for (int i = 0; i < kNLsePerThread; ++i) { | ||||||
|  |         const int split = i * 32 + threadIdx.x; | ||||||
|  |         if (split < local_split_kv) { | ||||||
|  |           sLseScale[split] = expf(local_lse[i] - global_lse); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     // Block-wide barrier between warp 0's writes to sLseScale and the reads below. | ||||||
|  |     __syncthreads(); | ||||||
|  |  | ||||||
|  |     // Each thread accumulates `Elements` outputs, at indices threadIdx.x + MaxThreadsPerBlock * i. | ||||||
|  |     constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; | ||||||
|  |     const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); | ||||||
|  |     Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), | ||||||
|  |                                Shape<Int<kHeadDimLatent>>{}, Stride<_1>{}); | ||||||
|  |     ElementAcc local_val[Elements] = {0}; | ||||||
|  |     for (int split = 0; split < local_split_kv; ++split) { | ||||||
|  |       ElementAcc lse_scale = sLseScale[split]; | ||||||
|  |       CUTLASS_PRAGMA_UNROLL | ||||||
|  |       for(int i = 0; i < Elements; ++i) { | ||||||
|  |         local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); | ||||||
|  |       } | ||||||
|  |       // Advance to the next split's partial output, kHeadDimLatent elements further on. | ||||||
|  |       gOaccum.data() = gOaccum.data() + kHeadDimLatent; | ||||||
|  |     } | ||||||
|  |     auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; | ||||||
|  |     Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{}); | ||||||
|  |  | ||||||
|  |     // Convert the accumulated values to the output element type and store them. | ||||||
|  |     CUTLASS_PRAGMA_UNROLL | ||||||
|  |     for(int i = 0; i < Elements; ++i) { | ||||||
|  |       gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | }  // namespace cutlass::fmha::kernel | ||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -0,0 +1,165 @@ | |||||||
|  | /*************************************************************************************************** | ||||||
|  |  * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights | ||||||
|  |  *reserved. SPDX-License-Identifier: BSD-3-Clause | ||||||
|  |  * | ||||||
|  |  * Redistribution and use in source and binary forms, with or without | ||||||
|  |  * modification, are permitted provided that the following conditions are met: | ||||||
|  |  * | ||||||
|  |  * 1. Redistributions of source code must retain the above copyright notice, | ||||||
|  |  *this list of conditions and the following disclaimer. | ||||||
|  |  * | ||||||
|  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | ||||||
|  |  * this list of conditions and the following disclaimer in the documentation | ||||||
|  |  * and/or other materials provided with the distribution. | ||||||
|  |  * | ||||||
|  |  * 3. Neither the name of the copyright holder nor the names of its | ||||||
|  |  * contributors may be used to endorse or promote products derived from | ||||||
|  |  * this software without specific prior written permission. | ||||||
|  |  * | ||||||
|  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||||||
|  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||||||
|  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||||||
|  |  *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||||||
|  |  *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||||||
|  |  *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||||||
|  |  *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||||||
|  |  *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||||||
|  |  *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||||||
|  |  *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE | ||||||
|  |  *POSSIBILITY OF SUCH DAMAGE. | ||||||
|  |  * | ||||||
|  |  **************************************************************************************************/ | ||||||
|  | /* | ||||||
|  |  * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 | ||||||
|  |  * by Alcanderian JieXin Liang | ||||||
|  |  */ | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "cutlass/cutlass.h" | ||||||
|  | #include "cutlass/fast_math.h" | ||||||
|  | #include "cutlass/kernel_hardware_info.h" | ||||||
|  |  | ||||||
|  | namespace cutlass::fmha::kernel { | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | // Tile scheduler in which each CTA processes exactly one tile, identified directly by its block index. | ||||||
|  | struct Sm100MlaIndividualTileScheduler { | ||||||
|  |  | ||||||
|  |   // Launch grid chosen on the host by to_underlying_arguments(). | ||||||
|  |   struct Params { | ||||||
|  |     dim3 grid; | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   // Stays true until the scheduler has been advanced once. | ||||||
|  |   bool valid_ = true; | ||||||
|  |  | ||||||
|  |   // The params carry no device-side state for this scheduler. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   Sm100MlaIndividualTileScheduler(Params const& /*params*/) {} | ||||||
|  |  | ||||||
|  |   // Grid layout: x = mode 0 of the cluster shape, y = batch (mode 3 of the problem shape), | ||||||
|  |   // z = maximum split-KV count. `hw_info` is not consulted. | ||||||
|  |   template<class ProblemShape, class ClusterShape> | ||||||
|  |   static Params to_underlying_arguments( | ||||||
|  |       ProblemShape const& problem_shape, KernelHardwareInfo hw_info, | ||||||
|  |       ClusterShape const& cluster_shape, int const& split_kv) { | ||||||
|  |     using namespace cute; | ||||||
|  |     dim3 launch_grid( | ||||||
|  |         get<0>(cluster_shape), | ||||||
|  |         get<3>(problem_shape),  // Batch | ||||||
|  |         split_kv);              // Maximum Split KV | ||||||
|  |     return Params{ launch_grid }; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   static dim3 get_grid_shape(Params const& params) { | ||||||
|  |     return params.grid; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // True while this CTA's single tile has not been consumed yet. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   bool is_valid() { | ||||||
|  |     return valid_; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Coordinate (blockIdx.x, 0, blockIdx.y, blockIdx.z), matching the grid layout above. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   auto get_block_coord() { | ||||||
|  |     using namespace cute; | ||||||
|  |     return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Advancing marks the scheduler as exhausted. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   Sm100MlaIndividualTileScheduler& operator++() { | ||||||
|  |     valid_ = false; | ||||||
|  |     return *this; | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | // Tile scheduler in which a one-dimensional grid of CTAs walks a linear tile index, each CTA | ||||||
|  | // advancing by gridDim.x per step until the index reaches the total tile count. | ||||||
|  | struct Sm100MlaPersistentTileScheduler { | ||||||
|  |  | ||||||
|  |   struct Params { | ||||||
|  |     int num_blocks;              // total number of tiles to process | ||||||
|  |     FastDivmod divmod_m_block;   // divides by the number of m blocks | ||||||
|  |     FastDivmod divmod_b;         // divides by the batch count | ||||||
|  |     FastDivmod divmod_split_kv;  // divides by the maximum split-KV count | ||||||
|  |     KernelHardwareInfo hw_info;  // carries the SM count that caps the grid size | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   // Linear index of the tile this CTA currently works on. | ||||||
|  |   int block_idx = 0; | ||||||
|  |   Params params; | ||||||
|  |  | ||||||
|  |   // Every CTA starts at the tile whose linear index equals its own block index. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} | ||||||
|  |  | ||||||
|  |   // Host-side setup: settles the SM count and computes the tile count as | ||||||
|  |   // (cluster mode 0) * batch * split_kv. | ||||||
|  |   template<class ProblemShape, class ClusterShape> | ||||||
|  |   static Params to_underlying_arguments( | ||||||
|  |       ProblemShape const& problem_shape, KernelHardwareInfo hw_info, | ||||||
|  |       ClusterShape const& cluster_shape, int const& split_kv) { | ||||||
|  |     using namespace cute; | ||||||
|  |     // Keep the user supplied SM count only if it is above one and divisible by the cluster's | ||||||
|  |     // mode-0 extent; otherwise query the device. | ||||||
|  |     int sm_count = hw_info.sm_count; | ||||||
|  |     bool const sm_count_usable = sm_count > 1 && sm_count % size<0>(cluster_shape) == 0; | ||||||
|  |     if (!sm_count_usable) { | ||||||
|  |       CUTLASS_TRACE_HOST("  WARNING: Arguments do not include a valid SM count.\n" | ||||||
|  |           "  For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); | ||||||
|  |       sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); | ||||||
|  |     hw_info.sm_count = sm_count; | ||||||
|  |  | ||||||
|  |     int num_m_blocks = size<0>(cluster_shape); | ||||||
|  |     int total_blocks = num_m_blocks * get<3>(problem_shape);  // times batch | ||||||
|  |     total_blocks *= split_kv;                                 // times maximum split KV | ||||||
|  |  | ||||||
|  |     return Params { | ||||||
|  |       total_blocks, | ||||||
|  |       { num_m_blocks}, | ||||||
|  |       { get<3>(problem_shape) }, | ||||||
|  |       {split_kv}, | ||||||
|  |       hw_info | ||||||
|  |     }; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // One-dimensional grid holding the smaller of the tile count and the SM count. | ||||||
|  |   static dim3 get_grid_shape(Params const& params) { | ||||||
|  |     return dim3(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // True while the current linear tile index is still inside the tile range. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   bool is_valid() { | ||||||
|  |     return block_idx < params.num_blocks; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Splits the linear tile index into (m_block, 0, batch, split) by successive divmods: | ||||||
|  |   // each call stores the remainder in its coordinate and carries the quotient to the next one. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   auto get_block_coord() { | ||||||
|  |     using namespace cute; | ||||||
|  |     int remaining = block_idx; | ||||||
|  |     int m_block, bidb, n_split_kv; | ||||||
|  |     params.divmod_m_block(remaining, m_block, remaining); | ||||||
|  |     params.divmod_b(remaining, bidb, remaining); | ||||||
|  |     params.divmod_split_kv(remaining, n_split_kv, remaining); | ||||||
|  |     return make_coord(m_block, _0{}, bidb, n_split_kv); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Moves on to this CTA's next tile, one full grid width further along. | ||||||
|  |   CUTLASS_DEVICE | ||||||
|  |   Sm100MlaPersistentTileScheduler& operator++() { | ||||||
|  |     block_idx += gridDim.x; | ||||||
|  |     return *this; | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | //////////////////////////////////////////////////////////////////////////////// | ||||||
|  |  | ||||||
|  | } // namespace cutlass::fmha::kernel | ||||||
							
								
								
									
										283
									
								
								csrc/attention/mla/sm100_cutlass_mla_kernel.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										283
									
								
								csrc/attention/mla/sm100_cutlass_mla_kernel.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,283 @@ | |||||||
|  | /* | ||||||
|  | Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | Copyright 2025 SGLang Team. All Rights Reserved. | ||||||
|  |  | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | /* | ||||||
|  |  * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 | ||||||
|  |  * by Alcanderian JieXin Liang | ||||||
|  |  */ | ||||||
|  | #include "core/registration.h" | ||||||
|  |  | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include <cutlass/cutlass.h> | ||||||
|  | #include <cutlass/kernel_hardware_info.h> | ||||||
|  | #include <torch/all.h> | ||||||
|  |  | ||||||
|  | #include <cute/tensor.hpp> | ||||||
|  | #include <iostream> | ||||||
|  |  | ||||||
|  | #include "cutlass_sm100_mla/device/sm100_mla.hpp" | ||||||
|  | #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  | #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 | ||||||
|  | // Fallback compiled when CUDA_VERSION is undefined or below 12040: every call fails | ||||||
|  | // through TORCH_CHECK with a message stating the required CUDA version. | ||||||
|  | void sm100_cutlass_mla_decode( | ||||||
|  |     torch::Tensor const& out, | ||||||
|  |     torch::Tensor const& q_nope, | ||||||
|  |     torch::Tensor const& q_pe, | ||||||
|  |     torch::Tensor const& kv_c_and_k_pe_cache, | ||||||
|  |     torch::Tensor const& seq_lens, | ||||||
|  |     torch::Tensor const& page_table, | ||||||
|  |     torch::Tensor const& workspace, | ||||||
|  |     int64_t num_kv_splits) { | ||||||
|  |   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); | ||||||
|  | } | ||||||
|  | // Fallback compiled when CUDA_VERSION is undefined or below 12040: never returns a size, | ||||||
|  | // the TORCH_CHECK(false, ...) below fails on every call. | ||||||
|  | int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { | ||||||
|  |   TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); | ||||||
|  | } | ||||||
|  | #else | ||||||
|  |  | ||||||
|  | // Evaluates `status` once into a local cutlass::Status and fails through TORCH_CHECK, using | ||||||
|  | // cutlassGetStatusString(error) as the message, unless it equals cutlass::Status::kSuccess. | ||||||
|  | // NOTE(review): this expands to a bare brace block rather than do { ... } while (0), so | ||||||
|  | // `CUTLASS_CHECK(x);` directly under an unbraced if that has an else would not parse as one statement. | ||||||
|  | #define CUTLASS_CHECK(status)                                                       \ | ||||||
|  |   {                                                                                 \ | ||||||
|  |     cutlass::Status error = status;                                                 \ | ||||||
|  |     TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | using namespace cute; | ||||||
|  | using namespace cutlass::fmha::kernel; | ||||||
|  |  | ||||||
|  | // Compile-time tag carrying a boolean as IsPersistent<v>::value. | ||||||
|  | template <bool v> | ||||||
|  | struct IsPersistent { | ||||||
|  |   static const bool value = v; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // Bundles the element types, shapes, strides, tile scheduler and kernel/device types for one | ||||||
|  | // MLA configuration. T is used as both the input and the output element type; accumulation is in float. | ||||||
|  | template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>> | ||||||
|  | struct MlaSm100 { | ||||||
|  |   using Element = T; | ||||||
|  |   using ElementAcc = float; | ||||||
|  |   using ElementOut = T; | ||||||
|  |  | ||||||
|  |   // Modes 0 and 2 of this shape are reused as the H and (D_latent D_rope) entries of ProblemShape. | ||||||
|  |   using TileShape = Shape<_128, _128, Shape<_512, _64>>; | ||||||
|  |   using TileShapeH = cute::tuple_element_t<0, TileShape>; | ||||||
|  |   using TileShapeD = cute::tuple_element_t<2, TileShape>; | ||||||
|  |  | ||||||
|  |   // H K (D_latent D_rope) B | ||||||
|  |   using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>; | ||||||
|  |  | ||||||
|  |   using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B | ||||||
|  |   using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B | ||||||
|  |   using StrideO = StrideK;                            // H D B | ||||||
|  |   using StrideLSE = cute::tuple<_1, int>;             // H B | ||||||
|  |  | ||||||
|  |   // PersistenceOption::value selects the persistent scheduler when true, the individual one otherwise. | ||||||
|  |   using TileScheduler = | ||||||
|  |       std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>; | ||||||
|  |  | ||||||
|  |   // kIsCpAsync is the negation of IsPaged128. | ||||||
|  |   using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< | ||||||
|  |       TileShape, | ||||||
|  |       Element, | ||||||
|  |       ElementAcc, | ||||||
|  |       ElementOut, | ||||||
|  |       ElementAcc, | ||||||
|  |       TileScheduler, | ||||||
|  |       /*kIsCpAsync=*/!IsPaged128>; | ||||||
|  |   using Fmha = cutlass::fmha::device::MLA<FmhaKernel>; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // Translates the tensors and scalar options of one decode call into T::Fmha::Arguments. | ||||||
|  | // The problem's K extent is page_size * pages-per-sequence, the output strides are built from the | ||||||
|  | // static head and latent dimensions, the LSE pointer is left null, and T::Fmha::set_split_kv() | ||||||
|  | // is applied to the result before it is returned. | ||||||
|  | template <typename T> | ||||||
|  | typename T::Fmha::Arguments args_from_options( | ||||||
|  |     at::Tensor const& out, | ||||||
|  |     at::Tensor const& q_nope, | ||||||
|  |     at::Tensor const& q_pe, | ||||||
|  |     at::Tensor const& kv_c_and_k_pe_cache, | ||||||
|  |     at::Tensor const& seq_lens, | ||||||
|  |     at::Tensor const& page_table, | ||||||
|  |     double sm_scale, | ||||||
|  |     int64_t num_kv_splits) { | ||||||
|  |   using Element = typename T::Element; | ||||||
|  |   using ElementOut = typename T::ElementOut; | ||||||
|  |   using ElementAcc = typename T::ElementAcc; | ||||||
|  |   using StrideQ = typename T::StrideQ; | ||||||
|  |   using StrideK = typename T::StrideK; | ||||||
|  |   using StrideO = typename T::StrideO; | ||||||
|  |   using StrideLSE = typename T::StrideLSE; | ||||||
|  |   using TileShapeH = typename T::TileShapeH; | ||||||
|  |   using TileShapeD = typename T::TileShapeD; | ||||||
|  |  | ||||||
|  |   // The device is taken from q_nope; its SM count is queried up front. | ||||||
|  |   cutlass::KernelHardwareInfo hw_info; | ||||||
|  |   hw_info.device_id = q_nope.device().index(); | ||||||
|  |   hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); | ||||||
|  |  | ||||||
|  |   // Runtime extents read from the tensor sizes. | ||||||
|  |   int const batches = static_cast<int>(q_nope.sizes()[0]); | ||||||
|  |   int const page_count_per_seq = static_cast<int>(page_table.sizes()[1]); | ||||||
|  |   int const page_count_total = static_cast<int>(kv_c_and_k_pe_cache.sizes()[0]); | ||||||
|  |   int const page_size = static_cast<int>(kv_c_and_k_pe_cache.sizes()[1]); | ||||||
|  |   int const max_seq_len = page_size * page_count_per_seq; | ||||||
|  |  | ||||||
|  |   // H K (D_latent D_rope) B | ||||||
|  |   auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); | ||||||
|  |   auto [H, K, D, B] = problem_shape; | ||||||
|  |   auto [D_latent, D_rope] = D; | ||||||
|  |  | ||||||
|  |   float const softmax_scale = static_cast<float>(sm_scale); | ||||||
|  |  | ||||||
|  |   // Query strides come from the tensors: mode 0 uses stride(1), mode 2 uses stride(0). | ||||||
|  |   StrideQ q_nope_stride = cute::make_tuple( | ||||||
|  |       static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0))); | ||||||
|  |   StrideQ q_pe_stride = cute::make_tuple( | ||||||
|  |       static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0))); | ||||||
|  |  | ||||||
|  |   // Cache strides are derived from the static dims: D_latent + D_rope per row, page_size rows per page. | ||||||
|  |   StrideK cache_stride = cute::make_tuple( | ||||||
|  |       static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope))); | ||||||
|  |   StrideLSE page_table_stride = cute::make_stride(_1{}, page_count_per_seq); | ||||||
|  |   StrideLSE lse_stride = cute::make_tuple(_1{}, 0 + H); | ||||||
|  |   StrideO out_stride = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent)); | ||||||
|  |  | ||||||
|  |   auto q_nope_data = static_cast<Element*>(q_nope.data_ptr()); | ||||||
|  |   auto q_pe_data = static_cast<Element*>(q_pe.data_ptr()); | ||||||
|  |   auto cache_data = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr()); | ||||||
|  |  | ||||||
|  |   // The cache pointer is passed twice: once as is, once advanced by D_latent elements. | ||||||
|  |   typename T::Fmha::Arguments arguments{ | ||||||
|  |       problem_shape, | ||||||
|  |       {softmax_scale, | ||||||
|  |        q_nope_data, | ||||||
|  |        q_nope_stride, | ||||||
|  |        q_pe_data, | ||||||
|  |        q_pe_stride, | ||||||
|  |        cache_data, | ||||||
|  |        cache_stride, | ||||||
|  |        cache_data + D_latent, | ||||||
|  |        cache_stride, | ||||||
|  |        static_cast<int*>(seq_lens.data_ptr()), | ||||||
|  |        static_cast<int*>(page_table.data_ptr()), | ||||||
|  |        page_table_stride, | ||||||
|  |        page_count_total, | ||||||
|  |        page_size}, | ||||||
|  |       {static_cast<ElementOut*>(out.data_ptr()), out_stride, static_cast<ElementAcc*>(nullptr), lse_stride}, | ||||||
|  |       hw_info, | ||||||
|  |       // TODO(trevor-m): Change split_kv back to -1 when | ||||||
|  |       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will | ||||||
|  |       // perform worse with larger context length and smaller batch sizes. | ||||||
|  |       num_kv_splits, // split_kv | ||||||
|  |       nullptr,       // is_var_split_kv | ||||||
|  |   }; | ||||||
|  |   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute | ||||||
|  |   // split_kv automatically based on batch size and sequence length to balance | ||||||
|  |   // workload across available SMs. Consider using var_split_kv for manual | ||||||
|  |   // control if needed. | ||||||
|  |   T::Fmha::set_split_kv(arguments); | ||||||
|  |   return arguments; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Runs the SM100 MLA kernel for one (Element, IsPaged128, PersistenceOption) | ||||||
|  | // instantiation. The kernel arguments come from args_from_options(); the | ||||||
|  | // can_implement(), initialize() and run() calls are each wrapped in | ||||||
|  | // CUTLASS_CHECK. The same `workspace` buffer and `stream` are passed to both | ||||||
|  | // initialize() and run(). | ||||||
|  | template <typename Element, bool IsPaged128, typename PersistenceOption> | ||||||
|  | void runMla( | ||||||
|  |     at::Tensor const& out, | ||||||
|  |     at::Tensor const& q_nope, | ||||||
|  |     at::Tensor const& q_pe, | ||||||
|  |     at::Tensor const& kv_c_and_k_pe_cache, | ||||||
|  |     at::Tensor const& seq_lens, | ||||||
|  |     at::Tensor const& page_table, | ||||||
|  |     at::Tensor const& workspace, | ||||||
|  |     double sm_scale, | ||||||
|  |     int64_t num_kv_splits, | ||||||
|  |     cudaStream_t stream) { | ||||||
|  |   using Kernel = MlaSm100<Element, IsPaged128, PersistenceOption>; | ||||||
|  |   using Op = typename Kernel::Fmha; | ||||||
|  |  | ||||||
|  |   Op op; | ||||||
|  |   auto kernel_args = args_from_options<Kernel>( | ||||||
|  |       out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); | ||||||
|  |   void* const workspace_ptr = workspace.data_ptr(); | ||||||
|  |  | ||||||
|  |   // Validate first, then set up the workspace, then launch. | ||||||
|  |   CUTLASS_CHECK(op.can_implement(kernel_args)); | ||||||
|  |   CUTLASS_CHECK(op.initialize(kernel_args, workspace_ptr, stream)); | ||||||
|  |   CUTLASS_CHECK(op.run(kernel_args, workspace_ptr, stream)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Turns the runtime condition `expr` into a compile-time constant. | ||||||
|  | // Expands to an immediately-invoked lambda returning bool: in each branch a | ||||||
|  | // `constexpr bool` named `const_expr` is declared (true / false) and the | ||||||
|  | // callable passed as the trailing argument is invoked, so its body can use | ||||||
|  | // `const_expr` as a template argument. The callable's result is returned. | ||||||
|  | #define DISPATCH_BOOL(expr, const_expr, ...) \ | ||||||
|  |   [&]() -> bool {                            \ | ||||||
|  |     if (expr) {                              \ | ||||||
|  |       constexpr bool const_expr = true;      \ | ||||||
|  |       return __VA_ARGS__();                  \ | ||||||
|  |     } else {                                 \ | ||||||
|  |       constexpr bool const_expr = false;     \ | ||||||
|  |       return __VA_ARGS__();                  \ | ||||||
|  |     }                                        \ | ||||||
|  |   }() | ||||||
|  |  | ||||||
|  | // Entry point of the SM100 CUTLASS MLA decode op. | ||||||
|  | // | ||||||
|  | // Picks a runMla instantiation from three runtime properties: | ||||||
|  | //   * IsPaged128  - whether kv_c_and_k_pe_cache.sizes()[1] (the page size) | ||||||
|  | //                   equals 128; | ||||||
|  | //   * persistence - IsPersistent<true> only when num_kv_splits <= 1; | ||||||
|  | //   * element     - Half, BFloat16 or Float8_e4m3fn, read from q_nope's dtype. | ||||||
|  | // Any other dtype fails through TORCH_CHECK. The launch uses q_nope's device | ||||||
|  | // and that device's current CUDA stream. | ||||||
|  | void sm100_cutlass_mla_decode( | ||||||
|  |     torch::Tensor const& out, | ||||||
|  |     torch::Tensor const& q_nope, | ||||||
|  |     torch::Tensor const& q_pe, | ||||||
|  |     torch::Tensor const& kv_c_and_k_pe_cache, | ||||||
|  |     torch::Tensor const& seq_lens, | ||||||
|  |     torch::Tensor const& page_table, | ||||||
|  |     torch::Tensor const& workspace, | ||||||
|  |     double sm_scale, | ||||||
|  |     int64_t num_kv_splits) { | ||||||
|  |   auto in_dtype = q_nope.dtype(); | ||||||
|  |   // Hand the tensor's Device to the guard directly instead of narrowing the | ||||||
|  |   // device index through a C-style cast to `char`, whose signedness is | ||||||
|  |   // implementation-defined. | ||||||
|  |   at::cuda::CUDAGuard device_guard{q_nope.device()}; | ||||||
|  |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); | ||||||
|  |   const int page_size = kv_c_and_k_pe_cache.sizes()[1]; | ||||||
|  |  | ||||||
|  |   // NOTE(alcanderian): IsPersistent has a bug with manual split_kv. | ||||||
|  |   // The kernel will hang if the batch is too large with a large num_kv_splits | ||||||
|  |   // (for example bs=8, num_kv_splits=8). | ||||||
|  |   // Maybe per-batch split kv will fix this. | ||||||
|  |   DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { | ||||||
|  |     DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { | ||||||
|  |       if (in_dtype == at::ScalarType::Half) { | ||||||
|  |         runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>( | ||||||
|  |           out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); | ||||||
|  |       } else if (in_dtype == at::ScalarType::BFloat16) { | ||||||
|  |         runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>( | ||||||
|  |           out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); | ||||||
|  |       } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { | ||||||
|  |         runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>( | ||||||
|  |           out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); | ||||||
|  |       } else { | ||||||
|  |         TORCH_CHECK(false, "Unsupported input data type of MLA"); | ||||||
|  |       } | ||||||
|  |       // DISPATCH_BOOL requires its callable to return bool. | ||||||
|  |       return true; | ||||||
|  |     }); | ||||||
|  |     return true; | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Returns Fmha::get_workspace_size() for a problem with the given maximum | ||||||
|  | // sequence length and batch count. When sm_count <= 0 the multiprocessor | ||||||
|  | // count of device 0 is queried instead. num_kv_splits is stored as split_kv | ||||||
|  | // before set_split_kv() is applied to the arguments. | ||||||
|  | int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { | ||||||
|  |   // The workspace size depends on ElementAcc and ElementLSE (same as | ||||||
|  |   // ElementAcc), which are float, so the Element type chosen here is | ||||||
|  |   // irrelevant. | ||||||
|  |   using Kernel = MlaSm100<cutlass::half_t, true>; | ||||||
|  |   using Op = typename Kernel::Fmha; | ||||||
|  |  | ||||||
|  |   // Only the problem shape, sm_count and split_kv are filled in. | ||||||
|  |   typename Op::Arguments args; | ||||||
|  |   args.problem_shape = cute::make_tuple( | ||||||
|  |       typename Kernel::TileShapeH{}, | ||||||
|  |       static_cast<int>(max_seq_len), | ||||||
|  |       typename Kernel::TileShapeD{}, | ||||||
|  |       static_cast<int>(num_batches)); | ||||||
|  |  | ||||||
|  |   if (sm_count <= 0) { | ||||||
|  |     // Assumes device 0 when the caller does not supply an SM count. | ||||||
|  |     args.hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0); | ||||||
|  |   } else { | ||||||
|  |     args.hw_info.sm_count = sm_count; | ||||||
|  |   } | ||||||
|  |   args.split_kv = num_kv_splits; | ||||||
|  |   Op::set_split_kv(args); | ||||||
|  |  | ||||||
|  |   return Op::get_workspace_size(args); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // Registers sm100_cutlass_mla_decode as the implementation of the | ||||||
|  | // "sm100_cutlass_mla_decode" op under the CUDA dispatch key. | ||||||
|  | TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { | ||||||
|  |   m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Registers sm100_cutlass_mla_get_workspace_size under the CatchAll key; the | ||||||
|  | // function takes only integer arguments (no tensors). | ||||||
|  | TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { | ||||||
|  |   m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // clang-format on | ||||||
| @ -18,12 +18,7 @@ | |||||||
|  */ |  */ | ||||||
|  |  | ||||||
| #include "attention_kernels.cuh" | #include "attention_kernels.cuh" | ||||||
|  | #include "cuda_compat.h" | ||||||
| #ifndef USE_ROCM |  | ||||||
|   #define WARP_SIZE 32 |  | ||||||
| #else |  | ||||||
|   #define WARP_SIZE warpSize |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||||
| @ -65,9 +60,6 @@ void paged_attention_v1_launcher( | |||||||
|   int kv_block_stride = key_cache.stride(0); |   int kv_block_stride = key_cache.stride(0); | ||||||
|   int kv_head_stride = key_cache.stride(1); |   int kv_head_stride = key_cache.stride(1); | ||||||
|  |  | ||||||
|   [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); |  | ||||||
|   assert(head_size % thread_group_size == 0); |  | ||||||
|  |  | ||||||
|   // NOTE: alibi_slopes is optional. |   // NOTE: alibi_slopes is optional. | ||||||
|   const float* alibi_slopes_ptr = |   const float* alibi_slopes_ptr = | ||||||
|       alibi_slopes |       alibi_slopes | ||||||
| @ -190,7 +182,6 @@ void paged_attention_v1( | |||||||
|                              CALL_V1_LAUNCHER_BLOCK_SIZE) |                              CALL_V1_LAUNCHER_BLOCK_SIZE) | ||||||
| } | } | ||||||
|  |  | ||||||
| #undef WARP_SIZE |  | ||||||
| #undef MAX | #undef MAX | ||||||
| #undef MIN | #undef MIN | ||||||
| #undef DIVIDE_ROUND_UP | #undef DIVIDE_ROUND_UP | ||||||
|  | |||||||
| @ -18,12 +18,7 @@ | |||||||
|  */ |  */ | ||||||
|  |  | ||||||
| #include "attention_kernels.cuh" | #include "attention_kernels.cuh" | ||||||
|  | #include "cuda_compat.h" | ||||||
| #ifndef USE_ROCM |  | ||||||
|   #define WARP_SIZE 32 |  | ||||||
| #else |  | ||||||
|   #define WARP_SIZE warpSize |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||||
| @ -66,9 +61,6 @@ void paged_attention_v2_launcher( | |||||||
|   int kv_block_stride = key_cache.stride(0); |   int kv_block_stride = key_cache.stride(0); | ||||||
|   int kv_head_stride = key_cache.stride(1); |   int kv_head_stride = key_cache.stride(1); | ||||||
|  |  | ||||||
|   [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); |  | ||||||
|   assert(head_size % thread_group_size == 0); |  | ||||||
|  |  | ||||||
|   // NOTE: alibi_slopes is optional. |   // NOTE: alibi_slopes is optional. | ||||||
|   const float* alibi_slopes_ptr = |   const float* alibi_slopes_ptr = | ||||||
|       alibi_slopes |       alibi_slopes | ||||||
| @ -200,7 +192,6 @@ void paged_attention_v2( | |||||||
|                              CALL_V2_LAUNCHER_BLOCK_SIZE) |                              CALL_V2_LAUNCHER_BLOCK_SIZE) | ||||||
| } | } | ||||||
|  |  | ||||||
| #undef WARP_SIZE |  | ||||||
| #undef MAX | #undef MAX | ||||||
| #undef MIN | #undef MIN | ||||||
| #undef DIVIDE_ROUND_UP | #undef DIVIDE_ROUND_UP | ||||||
|  | |||||||
| @ -137,8 +137,8 @@ FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size, | |||||||
| } | } | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
| FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, | FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, | ||||||
|                                         const int size) { |                                          const int size) { | ||||||
|   T max = max_data[0]; |   T max = max_data[0]; | ||||||
|   for (int i = 1; i < size; ++i) { |   for (int i = 1; i < size; ++i) { | ||||||
|     max = max >= max_data[i] ? max : max_data[i]; |     max = max >= max_data[i] ? max : max_data[i]; | ||||||
| @ -634,7 +634,7 @@ struct paged_attention_v2_impl { | |||||||
|  |  | ||||||
|         if (partition_num == 1) continue; |         if (partition_num == 1) continue; | ||||||
|  |  | ||||||
|         reducePartitonSoftmax( |         reducePartitionSoftmax( | ||||||
|             max_logits + seq_idx * num_heads * max_num_partitions + |             max_logits + seq_idx * num_heads * max_num_partitions + | ||||||
|                 head_idx * max_num_partitions, |                 head_idx * max_num_partitions, | ||||||
|             exp_sums + seq_idx * num_heads * max_num_partitions + |             exp_sums + seq_idx * num_heads * max_num_partitions + | ||||||
|  | |||||||
| @ -33,6 +33,8 @@ namespace vec_op { | |||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #define FORCE_INLINE __attribute__((always_inline)) inline | #define FORCE_INLINE __attribute__((always_inline)) inline | ||||||
|  | // Number of elements (lanes) in a single ASIMD vector `vec` of a given | ||||||
|  | // datatype: the vector's size in bytes divided by the size of one element | ||||||
|  | // (`vec[0]`). | ||||||
|  | #define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0])) | ||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
| template <typename T, T... indexes, typename F> | template <typename T, T... indexes, typename F> | ||||||
| @ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> { | |||||||
|   } |   } | ||||||
|  |  | ||||||
|   void save(void* ptr, const int elem_num) const { |   void save(void* ptr, const int elem_num) const { | ||||||
|     int full_blocks = elem_num / 8; |     int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|     int remainder = elem_num % 8; |     int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |  | ||||||
|     if (full_blocks > 0) { |     if (full_blocks > 0) { | ||||||
|       vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); |       vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); | ||||||
| @ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> { | |||||||
|              vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {}; |              vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {}; | ||||||
|  |  | ||||||
|   void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; }; |   void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; }; | ||||||
|  |   // Stores only the first `elem_num` bf16 lanes to `ptr`. Whole 8-lane | ||||||
|  |   // registers are written with vst1q_bf16; the leading lanes of the next | ||||||
|  |   // register are written one by one so nothing past `elem_num` is stored. | ||||||
|  |   void save(void* ptr, const int elem_num) const { | ||||||
|  |     const auto lanes_per_reg = NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     const int whole_regs = elem_num / lanes_per_reg; | ||||||
|  |     const int tail = elem_num % lanes_per_reg; | ||||||
|  |  | ||||||
|  |     for (int r = 0; r < whole_regs; ++r) { | ||||||
|  |       vst1q_bf16(reinterpret_cast<__bf16*>(ptr) + lanes_per_reg * r, | ||||||
|  |                  reg.val[r]); | ||||||
|  |     } | ||||||
|  |     if (tail <= 0) return; | ||||||
|  |  | ||||||
|  |     const bfloat16x8_t last = reg.val[whole_regs]; | ||||||
|  |     bfloat16_t* const out = reinterpret_cast<bfloat16_t*>(ptr) + whole_regs * 8; | ||||||
|  |     // Lane indices must be compile-time constants, hence the explicit cases. | ||||||
|  |     switch (tail) { | ||||||
|  |       case 7: out[6] = vgetq_lane_bf16(last, 6); [[fallthrough]]; | ||||||
|  |       case 6: out[5] = vgetq_lane_bf16(last, 5); [[fallthrough]]; | ||||||
|  |       case 5: out[4] = vgetq_lane_bf16(last, 4); [[fallthrough]]; | ||||||
|  |       case 4: out[3] = vgetq_lane_bf16(last, 3); [[fallthrough]]; | ||||||
|  |       case 3: out[2] = vgetq_lane_bf16(last, 2); [[fallthrough]]; | ||||||
|  |       case 2: out[1] = vgetq_lane_bf16(last, 1); [[fallthrough]]; | ||||||
|  |       case 1: out[0] = vgetq_lane_bf16(last, 0); break; | ||||||
|  |       default: break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct BF16Vec32 : public Vec<BF16Vec32> { | struct BF16Vec32 : public Vec<BF16Vec32> { | ||||||
| @ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> { | |||||||
|       : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}; |       : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}; | ||||||
|  |  | ||||||
|   void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; }; |   void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; }; | ||||||
|  |   // Stores only the first `elem_num` bf16 lanes to `ptr`. Whole 8-lane | ||||||
|  |   // registers are written with vst1q_bf16; the leading lanes of the next | ||||||
|  |   // register are written one by one so nothing past `elem_num` is stored. | ||||||
|  |   void save(void* ptr, const int elem_num) const { | ||||||
|  |     const auto lanes_per_reg = NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     const int whole_regs = elem_num / lanes_per_reg; | ||||||
|  |     const int tail = elem_num % lanes_per_reg; | ||||||
|  |  | ||||||
|  |     for (int r = 0; r < whole_regs; ++r) { | ||||||
|  |       vst1q_bf16(reinterpret_cast<__bf16*>(ptr) + lanes_per_reg * r, | ||||||
|  |                  reg.val[r]); | ||||||
|  |     } | ||||||
|  |     if (tail <= 0) return; | ||||||
|  |  | ||||||
|  |     const bfloat16x8_t last = reg.val[whole_regs]; | ||||||
|  |     bfloat16_t* const out = reinterpret_cast<bfloat16_t*>(ptr) + whole_regs * 8; | ||||||
|  |     // Lane indices must be compile-time constants, hence the explicit cases. | ||||||
|  |     switch (tail) { | ||||||
|  |       case 7: out[6] = vgetq_lane_bf16(last, 6); [[fallthrough]]; | ||||||
|  |       case 6: out[5] = vgetq_lane_bf16(last, 5); [[fallthrough]]; | ||||||
|  |       case 5: out[4] = vgetq_lane_bf16(last, 4); [[fallthrough]]; | ||||||
|  |       case 4: out[3] = vgetq_lane_bf16(last, 3); [[fallthrough]]; | ||||||
|  |       case 3: out[2] = vgetq_lane_bf16(last, 2); [[fallthrough]]; | ||||||
|  |       case 2: out[1] = vgetq_lane_bf16(last, 1); [[fallthrough]]; | ||||||
|  |       case 1: out[0] = vgetq_lane_bf16(last, 0); break; | ||||||
|  |       default: break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
| }; | }; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| @ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> { | |||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | // Sixteen int32 lanes held as four 128-bit registers (int32x4x4_t). | ||||||
|  | struct INT32Vec16 : public Vec<INT32Vec16> { | ||||||
|  |   constexpr static int VEC_ELEM_NUM = 16; | ||||||
|  |   // Overlay giving per-element access to the four registers. | ||||||
|  |   union AliasReg { | ||||||
|  |     int32x4x4_t reg; | ||||||
|  |     int32_t values[VEC_ELEM_NUM]; | ||||||
|  |   }; | ||||||
|  |   int32x4x4_t reg; | ||||||
|  |  | ||||||
|  |   // Loads 16 consecutive int32 values starting at `ptr`. | ||||||
|  |   explicit INT32Vec16(const void* ptr) { | ||||||
|  |     const int32_t* src = reinterpret_cast<const int32_t*>(ptr); | ||||||
|  |     reg.val[0] = vld1q_s32(src); | ||||||
|  |     reg.val[1] = vld1q_s32(src + 4); | ||||||
|  |     reg.val[2] = vld1q_s32(src + 8); | ||||||
|  |     reg.val[3] = vld1q_s32(src + 12); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Stores all 16 lanes to `ptr`. | ||||||
|  |   void save(int32_t* ptr) const { | ||||||
|  |     vst1q_s32(ptr, reg.val[0]); | ||||||
|  |     vst1q_s32(ptr + 4, reg.val[1]); | ||||||
|  |     vst1q_s32(ptr + 8, reg.val[2]); | ||||||
|  |     vst1q_s32(ptr + 12, reg.val[3]); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   // Stores only the first `elem_num` lanes to `ptr`: whole 4-lane registers | ||||||
|  |   // via vst1q_s32, then the leading lanes of the next register one by one. | ||||||
|  |   void save(int32_t* ptr, const int elem_num) const { | ||||||
|  |     int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |  | ||||||
|  |     // `ptr` is already int32_t*, so no cast (and no non-standard __int32_t) | ||||||
|  |     // is needed. | ||||||
|  |     for (int i = 0; i < full_blocks; i++) | ||||||
|  |       vst1q_s32(ptr + NUM_ELEMENTS_REG(reg.val[0]) * i, reg.val[i]); | ||||||
|  |  | ||||||
|  |     if (remainder > 0) { | ||||||
|  |       int32x4_t temp = reg.val[full_blocks]; | ||||||
|  |       int32_t* base = ptr + full_blocks * NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |       // remainder is in [1, 3] here; a full register never reaches this path. | ||||||
|  |       base[0] = vgetq_lane_s32(temp, 0); | ||||||
|  |       if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1); | ||||||
|  |       if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
| struct FP32Vec16 : public Vec<FP32Vec16> { | struct FP32Vec16 : public Vec<FP32Vec16> { | ||||||
|   constexpr static int VEC_ELEM_NUM = 16; |   constexpr static int VEC_ELEM_NUM = 16; | ||||||
|   union AliasReg { |   union AliasReg { | ||||||
| @ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> { | |||||||
|     reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); |     reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); | ||||||
|     reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); |     reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); | ||||||
|   }; |   }; | ||||||
|  |   // Converts each int32 lane of `v` to float, one register at a time. | ||||||
|  |   explicit FP32Vec16(const INT32Vec16& v) { | ||||||
|  |     for (int r = 0; r < 4; ++r) { | ||||||
|  |       reg.val[r] = vcvtq_f32_s32(v.reg.val[r]); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|   FP32Vec16 operator+(const FP32Vec16& b) const { |   FP32Vec16 operator+(const FP32Vec16& b) const { | ||||||
|     return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]), |     return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]), | ||||||
|                                     vaddq_f32(reg.val[1], b.reg.val[1]), |                                     vaddq_f32(reg.val[1], b.reg.val[1]), | ||||||
| @ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> { | |||||||
|                                     vdivq_f32(reg.val[3], b.reg.val[3])})); |                                     vdivq_f32(reg.val[3], b.reg.val[3])})); | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   // Lane-wise clamp: first raises each lane to at least `min`, then caps the | ||||||
|  |   // result at `max`. | ||||||
|  |   FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { | ||||||
|  |     float32x4x4_t clamped; | ||||||
|  |     for (int r = 0; r < 4; ++r) { | ||||||
|  |       const float32x4_t raised = vmaxq_f32(min.reg.val[r], reg.val[r]); | ||||||
|  |       clamped.val[r] = vminq_f32(max.reg.val[r], raised); | ||||||
|  |     } | ||||||
|  |     return FP32Vec16(clamped); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Lane-wise maximum of `b` and *this over all 16 lanes. | ||||||
|  |   FP32Vec16 max(const FP32Vec16& b) const { | ||||||
|  |     float32x4x4_t larger; | ||||||
|  |     for (int r = 0; r < 4; ++r) { | ||||||
|  |       larger.val[r] = vmaxq_f32(b.reg.val[r], reg.val[r]); | ||||||
|  |     } | ||||||
|  |     return FP32Vec16(larger); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Lane-wise maximum of *this and `b` over the first `elem_num` lanes only. | ||||||
|  |   // Whole 4-lane registers use vmaxq_f32; the lanes of a trailing partial | ||||||
|  |   // register are compared one by one with std::max. Lanes at index >= | ||||||
|  |   // elem_num hold this vector's own values rather than being left | ||||||
|  |   // uninitialized. | ||||||
|  |   FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { | ||||||
|  |     int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     // Seed from `reg` so every lane is initialized before vsetq_lane_f32 | ||||||
|  |     // reads the partial register below. | ||||||
|  |     float32x4x4_t temp = reg; | ||||||
|  |  | ||||||
|  |     for (int i = 0; i < full_blocks; i++) | ||||||
|  |       temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]); | ||||||
|  |  | ||||||
|  |     if (remainder > 0) { | ||||||
|  |       float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0), | ||||||
|  |                              vgetq_lane_f32(b.reg.val[full_blocks], 0)); | ||||||
|  |       temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0); | ||||||
|  |     } | ||||||
|  |     if (remainder > 1) { | ||||||
|  |       float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1), | ||||||
|  |                              vgetq_lane_f32(b.reg.val[full_blocks], 1)); | ||||||
|  |       temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1); | ||||||
|  |     } | ||||||
|  |     if (remainder > 2) { | ||||||
|  |       float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2), | ||||||
|  |                              vgetq_lane_f32(b.reg.val[full_blocks], 2)); | ||||||
|  |       temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2); | ||||||
|  |     } | ||||||
|  |     return FP32Vec16(temp); | ||||||
|  |   }; | ||||||
|  |  | ||||||
|  |   // Lane-wise minimum of `b` and *this over all 16 lanes. | ||||||
|  |   FP32Vec16 min(const FP32Vec16& b) const { | ||||||
|  |     float32x4x4_t smaller; | ||||||
|  |     for (int r = 0; r < 4; ++r) { | ||||||
|  |       smaller.val[r] = vminq_f32(b.reg.val[r], reg.val[r]); | ||||||
|  |     } | ||||||
|  |     return FP32Vec16(smaller); | ||||||
|  |   } | ||||||
|  |   // Lane-wise minimum of *this and `b` over the first `elem_num` lanes only. | ||||||
|  |   // Whole 4-lane registers use vminq_f32; the lanes of a trailing partial | ||||||
|  |   // register are compared one by one with std::min. Lanes at index >= | ||||||
|  |   // elem_num hold this vector's own values rather than being left | ||||||
|  |   // uninitialized. | ||||||
|  |   FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { | ||||||
|  |     int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     // Seed from `reg` so every lane is initialized before vsetq_lane_f32 | ||||||
|  |     // reads the partial register below. | ||||||
|  |     float32x4x4_t temp = reg; | ||||||
|  |     for (int i = 0; i < full_blocks; i++) | ||||||
|  |       temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]); | ||||||
|  |  | ||||||
|  |     if (remainder > 0) { | ||||||
|  |       float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0), | ||||||
|  |                              vgetq_lane_f32(b.reg.val[full_blocks], 0)); | ||||||
|  |       temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0); | ||||||
|  |     } | ||||||
|  |     if (remainder > 1) { | ||||||
|  |       float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1), | ||||||
|  |                              vgetq_lane_f32(b.reg.val[full_blocks], 1)); | ||||||
|  |       temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1); | ||||||
|  |     } | ||||||
|  |     if (remainder > 2) { | ||||||
|  |       float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2), | ||||||
|  |                              vgetq_lane_f32(b.reg.val[full_blocks], 2)); | ||||||
|  |       temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     return FP32Vec16(temp); | ||||||
|  |   }; | ||||||
|  |   // Lane-wise absolute value over all 16 lanes. | ||||||
|  |   FP32Vec16 abs() const { | ||||||
|  |     float32x4x4_t magnitude; | ||||||
|  |     for (int r = 0; r < 4; ++r) { | ||||||
|  |       magnitude.val[r] = vabsq_f32(reg.val[r]); | ||||||
|  |     } | ||||||
|  |     return FP32Vec16(magnitude); | ||||||
|  |   } | ||||||
|   float reduce_sum() const { |   float reduce_sum() const { | ||||||
|     AliasReg ar; |     AliasReg ar; | ||||||
|     ar.reg = reg; |     ar.reg = reg; | ||||||
| @ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> { | |||||||
|     return answer; |     return answer; | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   // Returns the largest of the 16 lanes, starting the scan from the lowest | ||||||
|  |   // finite float. | ||||||
|  |   float reduce_max() const { | ||||||
|  |     AliasReg lanes; | ||||||
|  |     lanes.reg = reg; | ||||||
|  |     float best = std::numeric_limits<float>::lowest(); | ||||||
|  |     unroll_loop<int, VEC_ELEM_NUM>([&best, &lanes](int i) { | ||||||
|  |       // Same selection rule as std::max(best, lanes.values[i]). | ||||||
|  |       if (best < lanes.values[i]) best = lanes.values[i]; | ||||||
|  |     }); | ||||||
|  |     return best; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Returns the smallest of the 16 lanes, starting the scan from the largest | ||||||
|  |   // finite float. | ||||||
|  |   float reduce_min() const { | ||||||
|  |     AliasReg lanes; | ||||||
|  |     lanes.reg = reg; | ||||||
|  |     float best = std::numeric_limits<float>::max(); | ||||||
|  |     unroll_loop<int, VEC_ELEM_NUM>([&best, &lanes](int i) { | ||||||
|  |       // Same selection rule as std::min(best, lanes.values[i]). | ||||||
|  |       if (lanes.values[i] < best) best = lanes.values[i]; | ||||||
|  |     }); | ||||||
|  |     return best; | ||||||
|  |   } | ||||||
|  |  | ||||||
|   template <int group_size> |   template <int group_size> | ||||||
|   float reduce_sub_sum(int idx) { |   float reduce_sub_sum(int idx) { | ||||||
|     static_assert(VEC_ELEM_NUM % group_size == 0); |     static_assert(VEC_ELEM_NUM % group_size == 0); | ||||||
| @ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> { | |||||||
|     vst1q_f32(ptr + 8, reg.val[2]); |     vst1q_f32(ptr + 8, reg.val[2]); | ||||||
|     vst1q_f32(ptr + 12, reg.val[3]); |     vst1q_f32(ptr + 12, reg.val[3]); | ||||||
|   }; |   }; | ||||||
|  |  | ||||||
|  |   // Stores only the first `elem_num` float lanes to `ptr`: whole 4-lane | ||||||
|  |   // registers via vst1q_f32, then up to three leading lanes of the next | ||||||
|  |   // register one by one. | ||||||
|  |   void save(float* ptr, const int elem_num) const { | ||||||
|  |     const auto lanes_per_reg = NUM_ELEMENTS_REG(reg.val[0]); | ||||||
|  |     const int whole_regs = elem_num / lanes_per_reg; | ||||||
|  |     const int tail = elem_num % lanes_per_reg; | ||||||
|  |  | ||||||
|  |     for (int r = 0; r < whole_regs; ++r) { | ||||||
|  |       vst1q_f32(ptr + lanes_per_reg * r, reg.val[r]); | ||||||
|  |     } | ||||||
|  |     if (tail <= 0) return; | ||||||
|  |  | ||||||
|  |     const float32x4_t last = reg.val[whole_regs]; | ||||||
|  |     float* const out = ptr + whole_regs * lanes_per_reg; | ||||||
|  |     // Lane indices must be compile-time constants, hence the explicit cases. | ||||||
|  |     switch (tail) { | ||||||
|  |       case 3: out[2] = vgetq_lane_f32(last, 2); [[fallthrough]]; | ||||||
|  |       case 2: out[1] = vgetq_lane_f32(last, 1); [[fallthrough]]; | ||||||
|  |       case 1: out[0] = vgetq_lane_f32(last, 0); break; | ||||||
|  |       default: break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // Sixteen int8 lanes held in one 128-bit register (int8x16_t). | ||||||
|  | struct INT8Vec16 : public Vec<INT8Vec16> { | ||||||
|  |   constexpr static int VEC_ELEM_NUM = 16; | ||||||
|  |   // Overlay giving per-element access to the register. | ||||||
|  |   union AliasReg { | ||||||
|  |     int8x16_t reg; | ||||||
|  |     int8_t values[VEC_ELEM_NUM]; | ||||||
|  |   }; | ||||||
|  |   int8x16_t reg; | ||||||
|  |  | ||||||
|  |   // Builds the vector from 16 floats: each 4-lane float register is converted | ||||||
|  |   // to int32 (vcvtq_s32_f32), then narrowed 32 -> 16 -> 8 bits with the | ||||||
|  |   // saturating vqmovn intrinsics and packed into one register. | ||||||
|  |   explicit INT8Vec16(const FP32Vec16& vec) { | ||||||
|  |     const int32x4_t quad0 = vcvtq_s32_f32(vec.reg.val[0]); | ||||||
|  |     const int32x4_t quad1 = vcvtq_s32_f32(vec.reg.val[1]); | ||||||
|  |     const int32x4_t quad2 = vcvtq_s32_f32(vec.reg.val[2]); | ||||||
|  |     const int32x4_t quad3 = vcvtq_s32_f32(vec.reg.val[3]); | ||||||
|  |  | ||||||
|  |     // Lanes 0-7 come from the first two registers, lanes 8-15 from the rest. | ||||||
|  |     const int16x8_t low_half = vcombine_s16(vqmovn_s32(quad0), vqmovn_s32(quad1)); | ||||||
|  |     const int16x8_t high_half = vcombine_s16(vqmovn_s32(quad2), vqmovn_s32(quad3)); | ||||||
|  |     reg = vcombine_s8(vqmovn_s16(low_half), vqmovn_s16(high_half)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Stores all 16 lanes to `ptr`. | ||||||
|  |   void save(int8_t* ptr) const { vst1q_s8(ptr, reg); } | ||||||
|  |  | ||||||
|  |   // Stores only the first `elem_num` lanes to `ptr`: the whole register per | ||||||
|  |   // full 16-lane block, then the leading lanes one by one for the tail. | ||||||
|  |   void save(int8_t* ptr, const int elem_num) const { | ||||||
|  |     const auto lanes_per_reg = NUM_ELEMENTS_REG(reg); | ||||||
|  |     const int whole_regs = elem_num / lanes_per_reg; | ||||||
|  |     const int tail = elem_num % lanes_per_reg; | ||||||
|  |  | ||||||
|  |     for (int r = 0; r < whole_regs; ++r) { | ||||||
|  |       vst1q_s8(ptr + lanes_per_reg * r, reg); | ||||||
|  |     } | ||||||
|  |     if (tail <= 0) return; | ||||||
|  |  | ||||||
|  |     int8_t* const out = ptr + whole_regs * lanes_per_reg; | ||||||
|  |     // Lane indices must be compile-time constants, hence the explicit cases. | ||||||
|  |     switch (tail) { | ||||||
|  |       case 15: out[14] = vgetq_lane_s8(reg, 14); [[fallthrough]]; | ||||||
|  |       case 14: out[13] = vgetq_lane_s8(reg, 13); [[fallthrough]]; | ||||||
|  |       case 13: out[12] = vgetq_lane_s8(reg, 12); [[fallthrough]]; | ||||||
|  |       case 12: out[11] = vgetq_lane_s8(reg, 11); [[fallthrough]]; | ||||||
|  |       case 11: out[10] = vgetq_lane_s8(reg, 10); [[fallthrough]]; | ||||||
|  |       case 10: out[9] = vgetq_lane_s8(reg, 9); [[fallthrough]]; | ||||||
|  |       case 9: out[8] = vgetq_lane_s8(reg, 8); [[fallthrough]]; | ||||||
|  |       case 8: out[7] = vgetq_lane_s8(reg, 7); [[fallthrough]]; | ||||||
|  |       case 7: out[6] = vgetq_lane_s8(reg, 6); [[fallthrough]]; | ||||||
|  |       case 6: out[5] = vgetq_lane_s8(reg, 5); [[fallthrough]]; | ||||||
|  |       case 5: out[4] = vgetq_lane_s8(reg, 4); [[fallthrough]]; | ||||||
|  |       case 4: out[3] = vgetq_lane_s8(reg, 3); [[fallthrough]]; | ||||||
|  |       case 3: out[2] = vgetq_lane_s8(reg, 2); [[fallthrough]]; | ||||||
|  |       case 2: out[1] = vgetq_lane_s8(reg, 1); [[fallthrough]]; | ||||||
|  |       case 1: out[0] = vgetq_lane_s8(reg, 0); break; | ||||||
|  |       default: break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| template <typename T> | template <typename T> | ||||||
|  | |||||||
| @ -83,7 +83,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> { | |||||||
|   explicit FP16Vec16(const void* ptr) |   explicit FP16Vec16(const void* ptr) | ||||||
|       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} |       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} | ||||||
|  |  | ||||||
|   // non-temproal load |   // non-temporal load | ||||||
|   explicit FP16Vec16(bool, void* ptr) |   explicit FP16Vec16(bool, void* ptr) | ||||||
|       : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} |       : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} | ||||||
|  |  | ||||||
| @ -120,7 +120,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> { | |||||||
|   explicit BF16Vec16(const void* ptr) |   explicit BF16Vec16(const void* ptr) | ||||||
|       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} |       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} | ||||||
|  |  | ||||||
|   // non-temproal load |   // non-temporal load | ||||||
|   explicit BF16Vec16(bool, void* ptr) |   explicit BF16Vec16(bool, void* ptr) | ||||||
|       : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} |       : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} | ||||||
|  |  | ||||||
| @ -327,7 +327,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> { | |||||||
|   // normal load |   // normal load | ||||||
|   explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} |   explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} | ||||||
|  |  | ||||||
|   // non-temproal load |   // non-temporal load | ||||||
|   explicit FP32Vec16(bool, void* ptr) |   explicit FP32Vec16(bool, void* ptr) | ||||||
|       : reg((__m512)_mm512_stream_load_si512(ptr)) {} |       : reg((__m512)_mm512_stream_load_si512(ptr)) {} | ||||||
|  |  | ||||||
| @ -576,7 +576,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> { | |||||||
|   // normal load |   // normal load | ||||||
|   explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {} |   explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {} | ||||||
|  |  | ||||||
|   // non-temproal load |   // non-temporal load | ||||||
|   explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {} |   explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {} | ||||||
|  |  | ||||||
|   void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); } |   void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); } | ||||||
| @ -587,7 +587,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> { | |||||||
|     _mm512_mask_storeu_epi8(ptr, mask, reg); |     _mm512_mask_storeu_epi8(ptr, mask, reg); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   // non-temproal save |   // non-temporal save | ||||||
|   void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); } |   void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); } | ||||||
| }; | }; | ||||||
| #endif | #endif | ||||||
|  | |||||||
| @ -57,6 +57,7 @@ class DNNLPrimitiveHelper { | |||||||
|   // Note: Due to the limitation of oneDNN |   // Note: Due to the limitation of oneDNN | ||||||
|   // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is |   // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is | ||||||
|   // not supported. |   // not supported. | ||||||
|  |  | ||||||
|   template <typename OutputT, typename BiasT> |   template <typename OutputT, typename BiasT> | ||||||
|   static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, |   static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, | ||||||
|                             const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, |                             const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, | ||||||
| @ -90,6 +91,27 @@ class DNNLPrimitiveHelper { | |||||||
|     } |     } | ||||||
|  |  | ||||||
|     dnnl::matmul::primitive_desc matmul_pd; |     dnnl::matmul::primitive_desc matmul_pd; | ||||||
|  | // Create memory descriptors with format_tag::any for the primitive. This | ||||||
|  | // enables the matmul primitive to choose memory layouts for an | ||||||
|  | // optimized primitive implementation, and these layouts may differ from the | ||||||
|  | // ones provided by the user. | ||||||
|  | #ifdef __aarch64__ | ||||||
|  |     auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, | ||||||
|  |                                          dnnl::memory::format_tag::any); | ||||||
|  |     auto mat_weights_md = dnnl::memory::desc( | ||||||
|  |         {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); | ||||||
|  |     auto mat_dst_md = | ||||||
|  |         dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); | ||||||
|  |     if (bias) { | ||||||
|  |       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); | ||||||
|  |       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, | ||||||
|  |                                                mat_weights_md, bias_md, | ||||||
|  |                                                mat_dst_md, attr); | ||||||
|  |     } else { | ||||||
|  |       matmul_pd = dnnl::matmul::primitive_desc( | ||||||
|  |           default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); | ||||||
|  |     } | ||||||
|  | #else | ||||||
|     if (bias) { |     if (bias) { | ||||||
|       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); |       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); | ||||||
|       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, |       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, | ||||||
| @ -98,6 +120,7 @@ class DNNLPrimitiveHelper { | |||||||
|       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, |       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, | ||||||
|                                                c_md, attr); |                                                c_md, attr); | ||||||
|     } |     } | ||||||
|  | #endif | ||||||
|     dnnl::matmul matmul(matmul_pd); |     dnnl::matmul matmul(matmul_pd); | ||||||
|  |  | ||||||
|     auto& engine = default_engine(); |     auto& engine = default_engine(); | ||||||
| @ -111,24 +134,34 @@ class DNNLPrimitiveHelper { | |||||||
|                             (void*)b_scales); |                             (void*)b_scales); | ||||||
|  |  | ||||||
|     auto& stream = default_stream(); |     auto& stream = default_stream(); | ||||||
|  |  | ||||||
|  |     auto mat_src_mem = a_m; | ||||||
|  |     auto mat_weights_mem = b_m; | ||||||
|  |     auto mat_dst_mem = c_m; | ||||||
|  | #ifdef __aarch64__ | ||||||
|  |     if (matmul_pd.weights_desc() != b_m.get_desc()) { | ||||||
|  |       mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); | ||||||
|  |       dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|     if constexpr (InputNoScale) { |     if constexpr (InputNoScale) { | ||||||
|       if (bias) { |       if (bias) { | ||||||
|         dnnl::memory::desc bias_md({N}, BiasType, {1}); |         dnnl::memory::desc bias_md({N}, BiasType, {1}); | ||||||
|         dnnl::memory bias_m(bias_md, engine, (void*)bias); |         dnnl::memory bias_m(bias_md, engine, (void*)bias); | ||||||
|         matmul.execute( |         matmul.execute( | ||||||
|             stream, { |             stream, { | ||||||
|                         {DNNL_ARG_SRC, a_m}, |                         {DNNL_ARG_SRC, mat_src_mem}, | ||||||
|                         {DNNL_ARG_WEIGHTS, b_m}, |                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||||
|                         {DNNL_ARG_BIAS, bias_m}, |                         {DNNL_ARG_BIAS, bias_m}, | ||||||
|                         {DNNL_ARG_DST, c_m}, |                         {DNNL_ARG_DST, mat_dst_mem}, | ||||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, |                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||||
|                     }); |                     }); | ||||||
|       } else { |       } else { | ||||||
|         matmul.execute( |         matmul.execute( | ||||||
|             stream, { |             stream, { | ||||||
|                         {DNNL_ARG_SRC, a_m}, |                         {DNNL_ARG_SRC, mat_src_mem}, | ||||||
|                         {DNNL_ARG_WEIGHTS, b_m}, |                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||||
|                         {DNNL_ARG_DST, c_m}, |                         {DNNL_ARG_DST, mat_dst_mem}, | ||||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, |                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||||
|                     }); |                     }); | ||||||
|       } |       } | ||||||
| @ -138,19 +171,19 @@ class DNNLPrimitiveHelper { | |||||||
|         dnnl::memory bias_m(bias_md, engine, (void*)bias); |         dnnl::memory bias_m(bias_md, engine, (void*)bias); | ||||||
|         matmul.execute( |         matmul.execute( | ||||||
|             stream, { |             stream, { | ||||||
|                         {DNNL_ARG_SRC, a_m}, |                         {DNNL_ARG_SRC, mat_src_mem}, | ||||||
|                         {DNNL_ARG_WEIGHTS, b_m}, |                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||||
|                         {DNNL_ARG_BIAS, bias_m}, |                         {DNNL_ARG_BIAS, bias_m}, | ||||||
|                         {DNNL_ARG_DST, c_m}, |                         {DNNL_ARG_DST, mat_dst_mem}, | ||||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, |                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, | ||||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, |                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||||
|                     }); |                     }); | ||||||
|       } else { |       } else { | ||||||
|         matmul.execute( |         matmul.execute( | ||||||
|             stream, { |             stream, { | ||||||
|                         {DNNL_ARG_SRC, a_m}, |                         {DNNL_ARG_SRC, mat_src_mem}, | ||||||
|                         {DNNL_ARG_WEIGHTS, b_m}, |                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||||
|                         {DNNL_ARG_DST, c_m}, |                         {DNNL_ARG_DST, mat_dst_mem}, | ||||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, |                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, | ||||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, |                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||||
|                     }); |                     }); | ||||||
| @ -170,5 +203,4 @@ class DNNLPrimitiveHelper { | |||||||
|     return stream; |     return stream; | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| #endif | #endif | ||||||
|  | |||||||
| @ -36,7 +36,7 @@ struct KernelVecType<c10::Half> { | |||||||
|   using cvt_vec_type = vec_op::FP32Vec16; |   using cvt_vec_type = vec_op::FP32Vec16; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| #ifdef __AVX512F__ | #if defined(__AVX512F__) || defined(__aarch64__) | ||||||
| template <bool AZP, typename scalar_t> | template <bool AZP, typename scalar_t> | ||||||
| void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||||
|                                    const float* scale, const int32_t* azp, |                                    const float* scale, const int32_t* azp, | ||||||
| @ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | |||||||
|                                    const float* scale, const int32_t* azp, |                                    const float* scale, const int32_t* azp, | ||||||
|                                    const int num_tokens, |                                    const int num_tokens, | ||||||
|                                    const int hidden_size) { |                                    const int hidden_size) { | ||||||
|   TORCH_CHECK( |   TORCH_CHECK(false, | ||||||
|       false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.") |               "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " | ||||||
|  |               "support.") | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| @ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | |||||||
|                                     float* scale, int32_t* azp, |                                     float* scale, int32_t* azp, | ||||||
|                                     const int num_tokens, |                                     const int num_tokens, | ||||||
|                                     const int hidden_size) { |                                     const int hidden_size) { | ||||||
|   TORCH_CHECK( |   TORCH_CHECK(false, | ||||||
|       false, |               "dynamic_scaled_int8_quant_impl requires " | ||||||
|       "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.") |               "AVX512/powerpc64/AArch64 support.") | ||||||
| } | } | ||||||
|  |  | ||||||
| template <bool PerChannel, typename scalar_t> | template <bool PerChannel, typename scalar_t> | ||||||
| @ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output, | |||||||
|                            const float a_scale, const float* b_scale, |                            const float a_scale, const float* b_scale, | ||||||
|                            const int32_t* azp_with_adj, const int num_tokens, |                            const int32_t* azp_with_adj, const int num_tokens, | ||||||
|                            const int hidden_size) { |                            const int hidden_size) { | ||||||
|   TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.") |   TORCH_CHECK( | ||||||
|  |       false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") | ||||||
| } | } | ||||||
|  |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| @ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, | |||||||
|                             const int32_t* azp, const int32_t* azp_with_adj, |                             const int32_t* azp, const int32_t* azp_with_adj, | ||||||
|                             const scalar_t* bias, const int num_tokens, |                             const scalar_t* bias, const int num_tokens, | ||||||
|                             const int hidden_size) { |                             const int hidden_size) { | ||||||
|   TORCH_CHECK(false, |   TORCH_CHECK( | ||||||
|               "dynamic_quant_epilogue requires AVX512/powerpc64 support.") |       false, | ||||||
|  |       "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
| }  // namespace | }  // namespace | ||||||
|  | |||||||
							
								
								
									
										238
									
								
								csrc/cpu/sgl-kernels/common.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								csrc/cpu/sgl-kernels/common.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,238 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/ATen.h> | ||||||
|  | #include <ATen/Parallel.h> | ||||||
|  | #include <ATen/record_function.h> | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | #if defined(_OPENMP) | ||||||
|  | #include <omp.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
// dispatch bool
//
// Turns a runtime boolean into a compile-time constant.
//
// Expands to an immediately-invoked `[&]` lambda that tests BOOL_V at runtime
// and, in each branch, declares `constexpr bool BOOL_NAME` (true or false)
// before invoking the callable passed as the trailing macro argument(s) with
// no arguments. The callable's result is returned from the lambda, so the
// whole macro is an expression, and the callable can use BOOL_NAME wherever a
// constant expression is required (e.g. as a template argument).
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...)                                 \
  [&] {                                                                          \
    if (BOOL_V) {                                                                \
      constexpr bool BOOL_NAME = true;                                           \
      return __VA_ARGS__();                                                      \
    } else {                                                                     \
      constexpr bool BOOL_NAME = false;                                          \
      return __VA_ARGS__();                                                      \
    }                                                                            \
  }()
|  |  | ||||||
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
//
// Turns a runtime at::ScalarType into a compile-time element type.
//
// Expands to an immediately-invoked `[&]` lambda that switches on TYPE and,
// for each supported case, introduces the alias `packed_t` before invoking the
// callable passed as the trailing macro argument(s) with no arguments:
//   BFloat16      -> at::BFloat16
//   Half          -> at::Half
//   Char          -> int8_t
//   Float8_e4m3fn -> at::Float8_e4m3fn
// Any other value reaches TORCH_CHECK(false, ...).
// NOTE(review): the failure message says "floating data type" although the
// integer type Char (int8_t) is accepted as well.
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                                    \
  [&] {                                                                         \
    switch (TYPE) {                                                             \
      case at::ScalarType::BFloat16 : {                                         \
        using packed_t = at::BFloat16;                                          \
        return __VA_ARGS__();                                                   \
      }                                                                         \
      case at::ScalarType::Half: {                                              \
        using packed_t = at::Half;                                              \
        return __VA_ARGS__();                                                   \
      }                                                                         \
      case at::ScalarType::Char : {                                             \
        using packed_t = int8_t;                                                \
        return __VA_ARGS__();                                                   \
      }                                                                         \
      case at::ScalarType::Float8_e4m3fn : {                                    \
        using packed_t = at::Float8_e4m3fn;                                     \
        return __VA_ARGS__();                                                   \
      }                                                                         \
      default:                                                                  \
        TORCH_CHECK(false, "Unsupported floating data type.\n");                \
    }                                                                           \
  }()
|  |  | ||||||
// Marks x as intentionally unused by evaluating it in a void context.
#define UNUSED(x) (void)(x)

// Argument validation helpers. Each one expands to TORCH_CHECK call(s) and
// stringizes its argument (#x) so the failure message names the offending
// expression.

// Requires x to be placed on the CPU device.
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")

// Requires x.is_contiguous().
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Requires only the last stride of x to be 1. The stride list is indexed with
// size() - 1, so x is assumed to have at least one dimension.
// The message starts with a space so it does not run into the stringized name.
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")

// CPU + fully contiguous. Expands to two statements, so it needs braces when
// used as the body of an unbraced if/for.
#define CHECK_INPUT(x) \
  CHECK_CPU(x);        \
  CHECK_CONTIGUOUS(x)
// CPU + contiguous along the last dimension. Also expands to two statements.
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CPU(x);                            \
  CHECK_LAST_DIM_CONTIGUOUS(x)

// Requires x.dim() == d.
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

// Requires a == b and reports both values on failure. Both a and b appear
// twice in the expansion (comparison and message).
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|  |  | ||||||
// parallel routines

// Grain size constant for the parallel helpers.
constexpr int GRAIN_SIZE = 1024;

// Ceiling division for integral types: the smallest q with q * y >= x when
// x >= 0 and y > 0. The sum x + y - 1 is formed before dividing.
template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
inline T div_up(T x, T y) {
  const auto biased = x + y - 1;
  return biased / y;
}
|  |  | ||||||
// Computes the half-open slice [n_start, n_end) of n work items that belongs
// to worker `ith` out of `nth` workers.
//
// Every worker is assigned ceil(n / nth) consecutive items. Workers whose
// slice would start beyond n receive the empty slice [n, n), so
// n_start <= n_end always holds on return.
//
//   n        total number of items
//   nth      number of workers (must be non-zero)
//   ith      index of this worker, 0 <= ith < nth
//   n_start  out: first item of the slice
//   n_end    out: one past the last item of the slice
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
    // onednn partition pattern
    T& n_my = n_end;
    if (nth <= 1 || n == 0) {
        n_start = 0;
        n_my = n;
    } else {
        T n1 = div_up(n, nth);
        T n2 = n1 - 1;
        T T1 = n - n2 * nth;
        n_my = ith < T1 ? n1 : n2;
        n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
    }
    n_end += n_start;
#else
    // pytorch aten partition pattern
    const T n_my = div_up(n, nth);
    // Clamp the start as well as the end: without it a trailing worker with
    // ith * n_my > n would be handed an inverted range (n_start > n_end).
    n_start = std::min(ith * n_my, n);
    n_end = std::min(n_start + n_my, n);
#endif
}
|  |  | ||||||
// Runs f(begin, end) over the index range [0, n).
//
// When built with OpenMP, a parallel region is opened and every thread calls
// f once with the slice that balance211() assigns to it. Otherwise f(0, n) is
// called once on the calling thread.
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if !defined(_OPENMP)
  // Serial build: a single call covers the whole range.
  f(0, n);
#else
#pragma omp parallel
  {
    const int team_size = omp_get_num_threads();
    const int thread_id = omp_get_thread_num();
    int slice_begin;
    int slice_end;
    balance211(n, team_size, thread_id, slice_begin, slice_end);
    f(slice_begin, slice_end);
  }
#endif
}
|  |  | ||||||
|  | // for 1d parallel, use `actual_nth` | ||||||
|  | // for 2d parallel, use even nths, e.g. 43->42 | ||||||
|  | int inline adjust_num_threads(int m) { | ||||||
|  |   int actual_nth = at::get_num_threads(); | ||||||
|  |   if (m == 1) { | ||||||
|  |     return actual_nth; | ||||||
|  |   } | ||||||
|  |   return std::max(1, (actual_nth >> 1) * 2); | ||||||
|  | } | ||||||
|  |  | ||||||
// Splits an m x n iteration space into a grid of tiles and calls
// f(begin_m, end_m, begin_n, end_n) once per tile.
//
// With OpenMP, adjust_num_threads(m) threads are arranged as an
// nth_m x nth_n grid and each thread handles one tile; tiles past the edge of
// the iteration space are empty (begin == end). Without OpenMP, f(0, m, 0, n)
// is called once. When m <= 0 or n <= 0 there is nothing to do and f is not
// called.
// NOTE(review): the OpenMP path assumes the runtime really provides the
// requested number of threads; tiles of missing threads would be skipped.
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
  // Empty problem: return before the aspect ratio below divides by n and
  // before div_up() could be asked to divide by nth_m == 0.
  if (m <= 0 || n <= 0) {
    return;
  }

#if defined(_OPENMP)
  // make sure we have even num_threads
  const int nth = adjust_num_threads(m);

  // [NOTE] thread blocking:
  //
  //   1) prefer square block per thread
  //   2) use even number of CPU cores
  //   3) use all `num_threads` cores
  //
  //   we have:
  //     TM * TN = T
  //     BM / TM = BN / TN
  //   then:
  //     TM = ((BM / BN) * T) ^ 0.5
  //
  const float r = float(m) / n;
  int nth_m = std::ceil(std::sqrt(r * nth));
  int nth_n = 1;
  // Walk down from the ideal TM to the nearest value that divides nth
  // exactly; nth_m == 1 always qualifies, so the loop ends with nth_m >= 1.
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) {
      break;
    }
  }

  // Tile extents are the same for every thread, so compute them once.
  const int thread_block_m = div_up(m, nth_m);
  const int thread_block_n = div_up(n, nth_n);

#pragma omp parallel num_threads(nth)
  {
    const int ith = omp_get_thread_num();
    const int ith_m = ith / nth_n;
    const int ith_n = ith % nth_n;

    // Clamp the begin as well as the end so that trailing tiles are empty
    // rather than inverted (begin > end).
    const int begin_m = std::min(m, ith_m * thread_block_m);
    const int end_m = std::min(m, begin_m + thread_block_m);
    const int begin_n = std::min(n, ith_n * thread_block_n);
    const int end_n = std::min(n, begin_n + thread_block_n);

    f(begin_m, end_m, begin_n, end_n);
  }
#else
  f(0, m, 0, n);
#endif
}
|  |  | ||||||
// Returns how many blocks of BLOCK_SIZE x K elements of type T fit into the
// L2 budget below; the result is never smaller than 1.
template <typename T>
int get_cache_blocks(int BLOCK_SIZE, int K) {
  // L2 2MB and ratio of 50%
  constexpr int kL2Budget = (2048 * 1024) >> 1;
  const auto block_bytes = BLOCK_SIZE * K * sizeof(T);
  const int fitting = int(kL2Budget / block_bytes);
  return std::max(1, fitting);
}
|  |  | ||||||
// data indexing for dimension collapse
//
// Decomposes a flat offset into per-dimension indices. The arguments after
// `offset` come in (index, extent) pairs; the last pair is resolved first
// (index = offset % extent, offset /= extent), then the one before it, and so
// on. The return value is what remains of the offset after all divisions.

// Recursion end: no dimensions left, hand the offset back unchanged.
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  // Resolve all later pairs first; they consume the low part of the offset.
  const T carried = data_index_init(offset, std::forward<Args>(args)...);
  x = carried % X;
  return carried / X;
}
|  |  | ||||||
// Advances a multi-dimensional index by one position. Arguments come in
// (index, extent) pairs; the last pair is incremented first and a dimension is
// only touched when every pair after it wrapped around to 0. Returns true when
// the first pair wrapped as well.

// Recursion end: with no dimensions left, report a wrap so that the last
// real dimension is always incremented.
inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  // A later dimension did not wrap, so this one stays as it is.
  if (!data_index_step(std::forward<Args>(args)...)) {
    return false;
  }
  const auto next = x + 1;
  x = (next == X) ? 0 : next;
  return x == 0;
}
|  |  | ||||||
// forced unroll for perf critical path

// ALWAYS_INLINE: `inline` plus the always_inline attribute where the compiler
// advertises it. __has_attribute is itself a compiler extension, so its
// presence is tested first; evaluating it directly in #if is a preprocessing
// error on compilers that do not define it.
#if defined(__has_attribute)
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif
#else
#define ALWAYS_INLINE inline
#endif

// Compile-time loop: Unroll<n>{}(f, args...) calls
//   f(std::integral_constant<int, 0>{}, args...)
//   ...
//   f(std::integral_constant<int, n - 1>{}, args...)
// in increasing order. The index is passed as a type so that f can use it in
// constant expressions. `args` are taken by value and copied into each call.
// Only n >= 1 is supported; the recursion ends at the Unroll<1> specialization.
template <int n>
struct Unroll {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    // Emit iterations 0 .. n-2 first, then iteration n-1.
    Unroll<n - 1>{}(f, args...);
    f(std::integral_constant<int, n - 1>{}, args...);
  }
};

// Recursion end: a single iteration with index 0.
template <>
struct Unroll<1> {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    f(std::integral_constant<int, 0>{}, args...);
  }
};
|  |  | ||||||
|  | } // anonymous namespace | ||||||
							
								
								
									
										464
									
								
								csrc/cpu/sgl-kernels/gemm.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										464
									
								
								csrc/cpu/sgl-kernels/gemm.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,464 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #include "common.h" | ||||||
|  | #include "vec.h" | ||||||
|  | #include "gemm.h" | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
// packed   layout:
//   quants {N, K}  int8_t
//   comp   {N}     int32_t
//
// Computes the compensation vector for one BLOCK_N-wide block of packed int8
// weights and stores it directly behind the quantized data, i.e. at byte
// offset BLOCK_N * K from `packed`.
//
// Each int32 lane ends up holding 0x80 * (sum of the int8 weights of one
// output column): _mm512_dpbusd_epi32 multiplies the unsigned bytes of `off`
// (all 0x80) with the signed weight bytes and adds four adjacent products
// into one lane.
// NOTE(review): presumably this cancels a +128 shift applied to the
// activations elsewhere; confirm against the kernel that reads `comp`.
//
// Only K / 4 groups of four are visited, so K is assumed to be a multiple of
// 4 (TODO confirm). Without CPU_CAPABILITY_AVX512 the function fails through
// TORCH_CHECK.
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
  // One 512-bit accumulator (16 x int32) per group of 16 columns.
  constexpr int COLS = BLOCK_N / 16;
  __m512i vcomp[COLS];

  for (int col = 0; col < COLS; ++col) {
    vcomp[col] = _mm512_setzero_si512();
  }

  // Byte offset of the compensation area: right after the BLOCK_N * K quants.
  const int64_t offset = BLOCK_N * K;
  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
  for (int k = 0; k < K / 4; ++k) {
    for (int col = 0; col < COLS; ++col) {
      // 64 bytes = 16 columns x 4 consecutive k values.
      __m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64));
      vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
    }
  }

  for (int col = 0; col < COLS; ++col) {
    _mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]);
  }
#else
  TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
|  |  | ||||||
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
//
// Copies every pair of consecutive K elements of row n of `weight` into pair
// slot n of group k of `packed`. K / 2 groups are written; a trailing element
// of an odd K is not copied.
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
  constexpr int kPair = 2;
  const int num_groups = K / kPair;
  for (int row = 0; row < N; ++row) {
    const packed_t* src_row = weight + row * K;
    for (int group = 0; group < num_groups; ++group) {
      packed_t* dst = packed + group * N * kPair + row * kPair;
      const packed_t* src = src_row + group * kPair;
      for (int lane = 0; lane < kPair; ++lane) {
        dst[lane] = src[lane];
      }
    }
  }
}
|  |  | ||||||
|  | template <> | ||||||
|  | inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { | ||||||
|  |   constexpr int BLOCK_N = block_size_n(); | ||||||
|  |   TORCH_CHECK(N == BLOCK_N); | ||||||
|  |  | ||||||
|  |   const int VNNI_BLK = 4; | ||||||
|  |   for (int n = 0; n < N; ++n) { | ||||||
|  |     for (int k = 0; k < K / VNNI_BLK; ++k) { | ||||||
|  |       for (int d = 0; d < VNNI_BLK; ++d) { | ||||||
|  |         packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   s8s8_compensation<BLOCK_N>(packed, K); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kVecSize = bVec::size(); | ||||||
|  |  | ||||||
|  |   int64_t d; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||||
|  |     fVec data0 = fVec::loadu(input + d); | ||||||
|  |     fVec data1 = fVec::loadu(input + d + fVec::size()); | ||||||
|  |     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||||
|  |     out_vec.store(out + d); | ||||||
|  |   } | ||||||
|  |   for (; d < size; ++d) { | ||||||
|  |     out[d] = static_cast<scalar_t>(input[d]); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kVecSize = bVec::size(); | ||||||
|  |  | ||||||
|  |   int64_t d; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||||
|  |     fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); | ||||||
|  |     fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); | ||||||
|  |     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||||
|  |     out_vec.store(out + d); | ||||||
|  |   } | ||||||
|  |   for (; d < size; ++d) { | ||||||
|  |     out[d] = static_cast<scalar_t>(input[d] + bias[d]); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_nn { | ||||||
|  |   static inline void apply( | ||||||
|  |       const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, | ||||||
|  |       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | // AVX512-BF16 specialization of the micro-kernel for at::BFloat16. | ||||||
|  | // Computes one BLOCK_M x BLOCK_N tile of C: products are accumulated in fp32 | ||||||
|  | // registers (optionally seeded with bias) and converted back to bf16 on store. | ||||||
|  | // A and B are both read as 32-bit elements, i.e. two adjacent bf16 values at a | ||||||
|  | // time, so B is presumably expected in the packed [K/2, N, 2] layout - TODO confirm | ||||||
|  | // against the packing routine. | ||||||
|  | template <bool has_bias, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> { | ||||||
|  |   static inline void apply( | ||||||
|  |       const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C, | ||||||
|  |       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |  | ||||||
|  |     // ROWS x COLS accumulator grid; each accumulator holds 16 fp32 values | ||||||
|  |     constexpr int ROWS = BLOCK_M; | ||||||
|  |     constexpr int COLS = BLOCK_N / 16; | ||||||
|  |  | ||||||
|  |     // prefetch distance | ||||||
|  |     // (0 here, so the prefetch branch in compute is compiled out) | ||||||
|  |     constexpr int PREFETCH_SIZE_K = 0; | ||||||
|  |  | ||||||
|  |     __m512bh va; | ||||||
|  |     __m512bh vb[COLS]; | ||||||
|  |     __m512 vc[ROWS * COLS]; | ||||||
|  |  | ||||||
|  |     // seed every accumulator with its 16-wide bias slice, or with zeros | ||||||
|  |     auto loadc = [&](auto i) { | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |       if constexpr (has_bias) { | ||||||
|  |         vc[i] = _mm512_loadu_ps(bias + col * 16); | ||||||
|  |       } else { | ||||||
|  |         vc[i] = _mm512_set1_ps(0.f); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(loadc); | ||||||
|  |  | ||||||
|  |     // K and lda are halved because the pointers below are reinterpreted as | ||||||
|  |     // 32-bit elements (one element = a pair of bf16 values); K >> 1 means an | ||||||
|  |     // odd K would leave its last element unprocessed. | ||||||
|  |     const int64_t K2 = K >> 1; | ||||||
|  |     const int64_t lda2 = lda >> 1; | ||||||
|  |     const int64_t ldb2 = ldb; // ldb * 2 >> 1; | ||||||
|  |     const float* a_ptr = reinterpret_cast<const float*>(A); | ||||||
|  |     const float* b_ptr = reinterpret_cast<const float*>(B); | ||||||
|  |  | ||||||
|  |     // One k-step for accumulator i = row * COLS + col. | ||||||
|  |     // va is refreshed only at col == 0 and vb[col] only at row == 0, then reused; | ||||||
|  |     // NOTE(review): this relies on Unroll invoking the lambda with i in | ||||||
|  |     // increasing order - Unroll is defined elsewhere, confirm. | ||||||
|  |     auto compute = [&](auto i, int64_t k) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       if constexpr (col == 0) { | ||||||
|  |         // broadcast one 32-bit element (bf16 pair) of A to all lanes | ||||||
|  |         va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); | ||||||
|  |       } | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         // load 16 consecutive 32-bit elements of B for this column group | ||||||
|  |         vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); | ||||||
|  |         if constexpr (PREFETCH_SIZE_K > 0) { | ||||||
|  |           _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       // bf16 pair dot-product accumulated into fp32 | ||||||
|  |       vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); | ||||||
|  |     }; | ||||||
|  |     for (int64_t k = 0; k < K2; ++k) { | ||||||
|  |       Unroll<ROWS * COLS>{}(compute, k); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // convert the fp32 accumulators to bf16 and write them to C | ||||||
|  |     auto storec = [&](auto i) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |       // for COLS = 2, 4 use 512bit store | ||||||
|  |       // for COLS = 1, 3 use 256bit store | ||||||
|  |       if constexpr (COLS % 2 == 0) { | ||||||
|  |         // even col handles itself and col + 1 in one store; odd col does nothing | ||||||
|  |         if constexpr (col % 2 == 0) { | ||||||
|  |           _mm512_storeu_si512( | ||||||
|  |               reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||||
|  |               (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); | ||||||
|  |         } | ||||||
|  |       } else { | ||||||
|  |         _mm256_storeu_si256( | ||||||
|  |             reinterpret_cast<__m256i*>(C + row * ldc + col * 16), | ||||||
|  |             (__m256i)(_mm512_cvtneps_pbh(vc[i]))); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(storec); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // Invokes the MB_SIZE x NB_SIZE micro-kernel on the sub-tile that starts at | ||||||
|  | // (mb_start, nb_start). The expansion reads scalar_t, has_bias, A, B, C, bias, | ||||||
|  | // K, lda, ldb, ldc, mb_start and nb_start from the enclosing scope and already | ||||||
|  | // ends with ';'. bias is only offset when has_bias is true. | ||||||
|  | // NOTE(review): B advances by nb_start * 2, presumably because each packed | ||||||
|  | // column holds two K-elements - confirm against the weight layout. | ||||||
|  | #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                          \ | ||||||
|  |     tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply(         \ | ||||||
|  |         A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ | ||||||
|  |         has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); | ||||||
|  |  | ||||||
|  | // brgemm path: the product is first written to the float scratch buffer Ctmp | ||||||
|  | // (row stride block_size_n()), then each of the M rows is converted into C, | ||||||
|  | // adding bias on the way when has_bias is set. | ||||||
|  | template <typename scalar_t, bool has_bias> | ||||||
|  | struct brgemm { | ||||||
|  |   static inline void apply( | ||||||
|  |       const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, | ||||||
|  |       float* __restrict__ Ctmp, const float* __restrict__ bias, | ||||||
|  |       int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |  | ||||||
|  |     // row stride of the scratch buffer | ||||||
|  |     constexpr int kTmpStride = block_size_n(); | ||||||
|  |     at::native::cpublas::brgemm(M, N, K, lda, ldb, kTmpStride, /* add_C */false, A, B, Ctmp); | ||||||
|  |  | ||||||
|  |     // write the float rows back to C in the output dtype | ||||||
|  |     for (int64_t row = 0; row < M; ++row) { | ||||||
|  |       scalar_t* c_row = C + row * ldc; | ||||||
|  |       const float* tmp_row = Ctmp + row * kTmpStride; | ||||||
|  |       if constexpr (has_bias) { | ||||||
|  |         copy_add_stub(c_row, tmp_row, bias, N); | ||||||
|  |       } else { | ||||||
|  |         copy_stub(c_row, tmp_row, N); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // Computes an M x N block of C from A and the packed weight B, optionally | ||||||
|  | // adding bias. | ||||||
|  | // | ||||||
|  | //   brg == true  : forwards the whole block to brgemm<>::apply (uses Ctmp). | ||||||
|  | //   brg == false : tiles the block into sub-tiles of at most 4 x 64 and | ||||||
|  | //                  dispatches each one to a tinygemm_kernel_nn micro-kernel. | ||||||
|  | // | ||||||
|  | // Only sub-tile widths of 32 and 64 columns (and heights 1..4) have a case | ||||||
|  | // below; any other combination fails the TORCH_CHECK in the default branch. | ||||||
|  | template <typename scalar_t, bool has_bias> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const scalar_t* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ bias, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg) { | ||||||
|  |  | ||||||
|  |   if (brg) { | ||||||
|  |     brgemm<scalar_t, has_bias>::apply( | ||||||
|  |         A, B, C, Ctmp, bias, | ||||||
|  |         M, N, K, lda, ldb, ldc); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // pattern: 1-4-16 | ||||||
|  |   constexpr int64_t BLOCK_M = 4; | ||||||
|  |   constexpr int64_t BLOCK_N = 64; | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |   // int64_t index so the comparison with MB is not a mixed-width one | ||||||
|  |   for (int64_t mb = 0; mb < MB; ++mb) { | ||||||
|  |     int64_t mb_start = mb * BLOCK_M; | ||||||
|  |     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||||
|  |     for (int64_t nb = 0; nb < NB; ++nb) { | ||||||
|  |       int64_t nb_start = nb * BLOCK_N; | ||||||
|  |       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||||
|  |  | ||||||
|  |       // dispatch key: high nibble = mb_size, low nibble = nb_size / 16 | ||||||
|  |       switch(mb_size << 4 | nb_size >> 4) { | ||||||
|  |         // mb_size = 1 | ||||||
|  |         case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; | ||||||
|  |         case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; | ||||||
|  |         // mb_size = 2 | ||||||
|  |         case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; | ||||||
|  |         case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; | ||||||
|  |         // mb_size = 3 | ||||||
|  |         case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; | ||||||
|  |         case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; | ||||||
|  |         // mb_size = 4 | ||||||
|  |         case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; | ||||||
|  |         case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; | ||||||
|  |         // report the actual tile shape (nb_size is the value, not the text "nb_size") | ||||||
|  |         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Computes out[M, N] = mat1[M, K] * packed mat2 (+ bias) in parallel. | ||||||
|  | // | ||||||
|  | //   out          : output rows, row stride out_strideM | ||||||
|  | //   mat1         : activation rows, row stride mat1_strideM | ||||||
|  | //   mat2         : packed weight; block nb starts at mat2 + nb * BLOCK_N * K | ||||||
|  | //   bias         : optional, may be nullptr | ||||||
|  | // | ||||||
|  | // The [MB, NB] block grid is split across threads by parallel_2d; inside a | ||||||
|  | // thread the N blocks are walked in groups of cache_blocks_nb. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void weight_packed_linear_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ out, | ||||||
|  |     const scalar_t* __restrict__ mat1, | ||||||
|  |     const scalar_t* __restrict__ mat2, | ||||||
|  |     const float* __restrict__ bias, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t mat1_strideM, | ||||||
|  |     int64_t out_strideM) { | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m(); | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |  | ||||||
|  |   // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx | ||||||
|  |   const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>); | ||||||
|  |  | ||||||
|  |   // l2 cache block for n | ||||||
|  |   int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K); | ||||||
|  |  | ||||||
|  |   // parallel on [MB, NB] | ||||||
|  |   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { | ||||||
|  |     parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { | ||||||
|  |  | ||||||
|  |       // for brgemm, use float32 for accumulate | ||||||
|  |       alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; | ||||||
|  |  | ||||||
|  |       for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { | ||||||
|  |       for (int64_t mb = begin_mb; mb < end_mb; ++mb) { | ||||||
|  |       for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { | ||||||
|  |  | ||||||
|  |         int64_t mb_start = mb * BLOCK_M; | ||||||
|  |         int64_t mb_size = std::min(M - mb_start, BLOCK_M); | ||||||
|  |         int64_t nb_start = nb * BLOCK_N; | ||||||
|  |         int64_t nb_size = std::min(N - nb_start, BLOCK_N); | ||||||
|  |  | ||||||
|  |         // bias is only offset when present: arithmetic on a null pointer with | ||||||
|  |         // a non-zero offset is undefined behavior | ||||||
|  |         tinygemm_kernel<scalar_t, has_bias>( | ||||||
|  |             /*   A */ mat1 + mb_start * mat1_strideM, | ||||||
|  |             /*   B */ mat2 + nb_start * K /* nb * BLOCK_N * K */, | ||||||
|  |             /*   C */ out + mb_start * out_strideM + nb_start, | ||||||
|  |             /* Ctmp*/ Ctmp, | ||||||
|  |             /* bias*/ has_bias ? bias + nb_start : nullptr, | ||||||
|  |             /*   M */ mb_size, | ||||||
|  |             /*   N */ nb_size, | ||||||
|  |             /*   K */ K, | ||||||
|  |             /* lda */ mat1_strideM, | ||||||
|  |             /* ldb */ nb_size, | ||||||
|  |             /* ldc */ out_strideM, | ||||||
|  |             /* brg */ use_brgemm); | ||||||
|  |       }}} | ||||||
|  |  | ||||||
|  |       // called once per worker after its blocks are done, only on the brgemm path | ||||||
|  |       if (use_brgemm) { | ||||||
|  |         at::native::cpublas::brgemm_release(); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // anonymous namespace | ||||||
|  |  | ||||||
|  | // tinygemm interface | ||||||
|  | // Bias-free entry point: forwards to the has_bias = false kernel with a null | ||||||
|  | // bias pointer. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const scalar_t* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg) { | ||||||
|  |   tinygemm_kernel<scalar_t, false>( | ||||||
|  |       A, B, C, Ctmp, /* bias */ nullptr, | ||||||
|  |       M, N, K, lda, ldb, ldc, brg); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Explicitly instantiates the bias-free tinygemm_kernel interface for TYPE. | ||||||
|  | // The macro body has no trailing ';', so each use below supplies it. | ||||||
|  | #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                             \ | ||||||
|  |     template void tinygemm_kernel<TYPE>(                                                \ | ||||||
|  |         const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C,   \ | ||||||
|  |         float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda,         \ | ||||||
|  |         int64_t ldb, int64_t ldc, bool brg) | ||||||
|  |  | ||||||
|  | // instantiated for the two reduced-precision floating types | ||||||
|  | INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); | ||||||
|  |  | ||||||
|  | // Repacks a 2d [OC, IC] or 3d [E, OC, IC] weight into the layout produced by | ||||||
|  | // pack_vnni and returns it as a new tensor. | ||||||
|  | // | ||||||
|  | // The result has the same sizes as the input except for the innermost | ||||||
|  | // dimension, which becomes get_row_size<packed_t>(IC). | ||||||
|  | // Fails a TORCH_CHECK unless the weight is 2d/3d, OC % TILE_N == 0, | ||||||
|  | // IC % TILE_K == 0 and the dtype is bfloat16, float16, int8 or fp8_e4m3. | ||||||
|  | at::Tensor convert_weight_packed(at::Tensor& weight) { | ||||||
|  |   // for 3d moe weights | ||||||
|  |   // weight : [E, OC, IC] | ||||||
|  |   //     w1 : [E, 2N,  K] | ||||||
|  |   //     w2 : [E,  K,  N] | ||||||
|  |   CHECK_INPUT(weight); | ||||||
|  |  | ||||||
|  |   const int64_t ndim = weight.ndimension(); | ||||||
|  |   TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); | ||||||
|  |   const auto st = weight.scalar_type(); | ||||||
|  |   // a 2d weight is treated as a single expert | ||||||
|  |   const int64_t E = ndim == 3 ? weight.size(0) : 1; | ||||||
|  |   const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); | ||||||
|  |   const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); | ||||||
|  |  | ||||||
|  |   // we handle 2 TILE_N at a time. | ||||||
|  |   TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); | ||||||
|  |   TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |   const int64_t NB = div_up(OC, BLOCK_N); | ||||||
|  |  | ||||||
|  |   // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2] | ||||||
|  |   auto packed_weight = at::empty({}, weight.options()); | ||||||
|  |   const int64_t stride = OC * IC; | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, | ||||||
|  |       "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); | ||||||
|  |  | ||||||
|  |   CPU_DISPATCH_PACKED_TYPES(st, [&] { | ||||||
|  |     // adjust most inner dimension size | ||||||
|  |     // kept as int64_t: get_row_size returns int64_t and the value feeds | ||||||
|  |     // 64-bit size/offset arithmetic below, so it must not be narrowed to int | ||||||
|  |     const int64_t packed_row_size = get_row_size<packed_t>(IC); | ||||||
|  |     auto sizes = weight.sizes().vec(); | ||||||
|  |     sizes[ndim - 1] = packed_row_size; | ||||||
|  |     packed_weight.resize_(sizes); | ||||||
|  |  | ||||||
|  |     const packed_t* w_data = weight.data_ptr<packed_t>(); | ||||||
|  |     packed_t* packed_data = packed_weight.data_ptr<packed_t>(); | ||||||
|  |  | ||||||
|  |     // parallel on {E, NB} | ||||||
|  |     at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |       // decompose the flat start index into (expert, N-block) | ||||||
|  |       int64_t e{0}, nb{0}; | ||||||
|  |       data_index_init(begin, e, E, nb, NB); | ||||||
|  |  | ||||||
|  |       for (int64_t i = begin; i < end; ++i) { | ||||||
|  |         UNUSED(i); | ||||||
|  |  | ||||||
|  |         // rows [n, n + n_size) of expert e; the last block may be short | ||||||
|  |         int64_t n = nb * BLOCK_N; | ||||||
|  |         int64_t n_size = std::min(BLOCK_N, OC - n); | ||||||
|  |         pack_vnni<packed_t>( | ||||||
|  |             packed_data + e * OC * packed_row_size + n * packed_row_size, | ||||||
|  |             w_data + e * stride + n * IC, | ||||||
|  |             n_size, | ||||||
|  |             IC); | ||||||
|  |  | ||||||
|  |         // move to the next index | ||||||
|  |         data_index_step(e, E, nb, NB); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  |   return packed_weight; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // mat1 : [M, K] | ||||||
|  | // mat2 : [N, K] | ||||||
|  | // bias : [N] | ||||||
|  | // out  : [M, N] | ||||||
|  | // | ||||||
|  | // Linear layer on a packed weight. When is_vnni is false, mat2 is packed with | ||||||
|  | // convert_weight_packed first; otherwise it is used as given. bias, when | ||||||
|  | // present, is read as float. Returns a new [M, N] tensor with mat1's options. | ||||||
|  | at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, | ||||||
|  |     const std::optional<at::Tensor>& bias, bool is_vnni) { | ||||||
|  |   RECORD_FUNCTION( | ||||||
|  |     "sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias})); | ||||||
|  |  | ||||||
|  |   // validate layout and rank before any size is indexed and before the | ||||||
|  |   // (potentially expensive) weight repack | ||||||
|  |   CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); | ||||||
|  |   CHECK_INPUT(mat2); | ||||||
|  |   CHECK_DIM(2, mat1); | ||||||
|  |   CHECK_DIM(2, mat2); | ||||||
|  |  | ||||||
|  |   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||||
|  |  | ||||||
|  |   int64_t M = mat1.size(0); | ||||||
|  |   int64_t N = mat2.size(0); | ||||||
|  |   int64_t K = mat2.size(1); | ||||||
|  |   CHECK_EQ(mat1.size(1), K); | ||||||
|  |  | ||||||
|  |   auto out = at::empty({M, N}, mat1.options()); | ||||||
|  |  | ||||||
|  |   // strides | ||||||
|  |   int64_t mat1_strideM = mat1.stride(0); | ||||||
|  |   int64_t out_strideM = out.stride(0); | ||||||
|  |  | ||||||
|  |   const bool has_bias = bias.has_value(); | ||||||
|  |   const float* bias_data = nullptr; | ||||||
|  |   if (has_bias) { | ||||||
|  |     CHECK_EQ(bias.value().size(0), N); | ||||||
|  |     bias_data = bias.value().data_ptr<float>(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { | ||||||
|  |     weight_packed_linear_kernel_impl<scalar_t>( | ||||||
|  |         out.data_ptr<scalar_t>(), | ||||||
|  |         mat1.data_ptr<scalar_t>(), | ||||||
|  |         packed_w.data_ptr<scalar_t>(), | ||||||
|  |         bias_data, | ||||||
|  |         M, | ||||||
|  |         N, | ||||||
|  |         K, | ||||||
|  |         mat1_strideM, | ||||||
|  |         out_strideM); | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   return out; | ||||||
|  | } | ||||||
							
								
								
									
										266
									
								
								csrc/cpu/sgl-kernels/gemm.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										266
									
								
								csrc/cpu/sgl-kernels/gemm.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,266 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/native/CPUBlas.h> | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | // amx-bf16 | ||||||
|  | #define TILE_M 16 | ||||||
|  | #define TILE_N 16 | ||||||
|  | #define TILE_K 32 | ||||||
|  |  | ||||||
|  | // block size for AMX gemm | ||||||
|  | // Each block spans two tiles per dimension (2 * TILE_M rows, 2 * TILE_N columns). | ||||||
|  | constexpr int block_size_m() { return 2 * TILE_M; } | ||||||
|  | constexpr int block_size_n() { return 2 * TILE_N; } | ||||||
|  |  | ||||||
|  | // define threshold using brgemm (intel AMX) | ||||||
|  | // Declared for any T but only defined through the specializations below; | ||||||
|  | // M is the row count the decision is based on. | ||||||
|  | template <typename T> inline bool can_use_brgemm(int M); | ||||||
|  | // bfloat16: brgemm only when M > 4 | ||||||
|  | template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; } | ||||||
|  | // float16: always brgemm, M is ignored | ||||||
|  | template <> inline bool can_use_brgemm<at::Half>(int M) { return true; } | ||||||
|  | // TODO: add u8s8 brgemm, this requires PyTorch 2.7 | ||||||
|  | // int8: never brgemm for now, M is ignored | ||||||
|  | template <> inline bool can_use_brgemm<int8_t>(int M) { return false; } | ||||||
|  | // fp8 e4m3 and 4-bit packed weights: same M > 4 threshold as bfloat16 | ||||||
|  | template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; } | ||||||
|  | template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; } | ||||||
|  |  | ||||||
|  | // work around compiler internal error | ||||||
|  | #define BLOCK_K 128 // 4 * TILE_K | ||||||
|  |  | ||||||
|  | // adjust leading dimension size for K | ||||||
|  | // Returns the length of one packed weight row for a logical row of K elements. | ||||||
|  | // For every type except int8_t this is K itself. | ||||||
|  | template <typename T> | ||||||
|  | inline int64_t get_row_size(int64_t K) { | ||||||
|  |   return K; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // int8 rows are sizeof(int32_t) elements longer than K. | ||||||
|  | // NOTE(review): presumably room for a per-row int32 value stored after the | ||||||
|  | // data - confirm against the int8 packing code. | ||||||
|  | template <> | ||||||
|  | inline int64_t get_row_size<int8_t>(int64_t K) { | ||||||
|  |   // keep the arithmetic signed instead of silently converting K to unsigned | ||||||
|  |   return K + static_cast<int64_t>(sizeof(int32_t)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Runtime-flag variant: same result as get_row_size<int8_t>(K) when | ||||||
|  | // use_int8_w8a8 is set, K otherwise. | ||||||
|  | inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { | ||||||
|  |   return use_int8_w8a8 ? get_row_size<int8_t>(K) : K; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // pack weight to vnni format | ||||||
|  | at::Tensor convert_weight_packed(at::Tensor& weight); | ||||||
|  |  | ||||||
|  | // moe implementations for int8 w8a8 | ||||||
|  | template <typename scalar_t> | ||||||
|  | void fused_experts_int8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ ic2, | ||||||
|  |     uint8_t* __restrict__ A_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     uint8_t* __restrict__ Aq_tmp, | ||||||
|  |     float* __restrict__ As_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const int8_t* __restrict__ packed_w1, | ||||||
|  |     const int8_t* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     const float* __restrict__ topk_weights, | ||||||
|  |     const int32_t* __restrict__ sorted_ids, | ||||||
|  |     const int32_t* __restrict__ expert_ids, | ||||||
|  |     const int32_t* __restrict__ offsets, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t E, | ||||||
|  |     int64_t topk, | ||||||
|  |     int64_t num_tokens_post_pad); | ||||||
|  |  | ||||||
|  | // moe implementations for fp8 w8a16 | ||||||
|  | template <typename scalar_t> | ||||||
|  | void fused_experts_fp8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic0, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ ic2, | ||||||
|  |     scalar_t* __restrict__ A_tmp, | ||||||
|  |     scalar_t* __restrict__ B_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     int64_t block_size_N, | ||||||
|  |     int64_t block_size_K, | ||||||
|  |     const float* __restrict__ topk_weights, | ||||||
|  |     const int32_t* __restrict__ sorted_ids, | ||||||
|  |     const int32_t* __restrict__ expert_ids, | ||||||
|  |     const int32_t* __restrict__ offsets, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t E, | ||||||
|  |     int64_t topk, | ||||||
|  |     int64_t num_tokens_post_pad); | ||||||
|  |  | ||||||
|  | // moe implementations for int4 w4a16 | ||||||
|  | template <typename scalar_t> | ||||||
|  | void fused_experts_int4_w4a16_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic0, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ ic2, | ||||||
|  |     scalar_t* __restrict__ A_tmp, | ||||||
|  |     scalar_t* __restrict__ B_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const at::quint4x2* __restrict__ packed_w1, | ||||||
|  |     const at::quint4x2* __restrict__ packed_w2, | ||||||
|  |     const uint8_t* __restrict__ w1z, | ||||||
|  |     const uint8_t* __restrict__ w2z, | ||||||
|  |     const scalar_t* __restrict__ w1s, | ||||||
|  |     const scalar_t* __restrict__ w2s, | ||||||
|  |     int group_size, | ||||||
|  |     const float* __restrict__ topk_weights, | ||||||
|  |     const int32_t* __restrict__ sorted_ids, | ||||||
|  |     const int32_t* __restrict__ expert_ids, | ||||||
|  |     const int32_t* __restrict__ offsets, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t E, | ||||||
|  |     int64_t topk, | ||||||
|  |     int64_t num_tokens_post_pad); | ||||||
|  |  | ||||||
|  | // shared expert implementation for int8 w8a8 | ||||||
|  | template <typename scalar_t> | ||||||
|  | void shared_expert_int8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     uint8_t* __restrict__ Aq_tmp, | ||||||
|  |     float* __restrict__ As_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const int8_t* __restrict__ packed_w1, | ||||||
|  |     const int8_t* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     const scalar_t* __restrict__ fused_experts_out, | ||||||
|  |     float routed_scaling_factor, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K); | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | void shared_expert_fp8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic0, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ B_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     int64_t block_size_N, | ||||||
|  |     int64_t block_size_K, | ||||||
|  |     const scalar_t* __restrict__ fused_experts_out, | ||||||
|  |     float routed_scaling_factor, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K); | ||||||
|  |  | ||||||
|  | // tinygemm interface | ||||||
|  | // Bias-free GEMM for same-typed A, B and C (scalar_t), with a float scratch | ||||||
|  | // buffer Ctmp. lda / ldb / ldc are the leading dimensions of A / B / C; | ||||||
|  | // brg selects the brgemm path instead of the small-tile kernels. | ||||||
|  | // NOTE(review): explicit instantiations appear to exist only for | ||||||
|  | // at::BFloat16 and at::Half - confirm before using another type. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const scalar_t* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg); | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const uint8_t* __restrict__ A, | ||||||
|  |     const int8_t* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     int32_t* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ As, | ||||||
|  |     const float* __restrict__ Bs, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg); | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     scalar_t* __restrict__ Btmp, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ scale, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg, | ||||||
|  |     int64_t block_size_K); | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const at::quint4x2* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     const uint8_t* __restrict__ Bz, | ||||||
|  |     const scalar_t* __restrict__ Bs, | ||||||
|  |     scalar_t* __restrict__ Btmp, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int group_size, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     int64_t strideBz, | ||||||
|  |     int64_t strideBs, | ||||||
|  |     bool brg); | ||||||
|  |  | ||||||
|  | // TODO: debug print, remove me later | ||||||
|  | // Prints the sixteen 32-bit integer lanes of x on one line of std::cout. | ||||||
|  | inline void print_16x32i(const __m512i x) { | ||||||
|  |   int32_t lanes[16]; | ||||||
|  |   _mm512_storeu_si512(reinterpret_cast<__m512i*>(lanes), x); | ||||||
|  |  | ||||||
|  |   for (const int32_t lane : lanes) { | ||||||
|  |     std::cout << lane << " "; | ||||||
|  |   } | ||||||
|  |   std::cout << std::endl; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Debug helper: prints the sixteen float lanes of x on one line of std::cout. | ||||||
|  | inline void print_16x32(const __m512 x) { | ||||||
|  |   float lanes[16]; | ||||||
|  |   _mm512_storeu_ps(lanes, x); | ||||||
|  |  | ||||||
|  |   for (const float lane : lanes) { | ||||||
|  |     std::cout << lane << " "; | ||||||
|  |   } | ||||||
|  |   std::cout << std::endl; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | // Debug helper: prints the thirty-two unsigned byte lanes of x as integers on | ||||||
|  | // one line of std::cout. | ||||||
|  | inline void print_32x8u(const __m256i x) { | ||||||
|  |   uint8_t lanes[32]; | ||||||
|  |   _mm256_storeu_si256(reinterpret_cast<__m256i*>(lanes), x); | ||||||
|  |  | ||||||
|  |   for (const uint8_t lane : lanes) { | ||||||
|  |     // widen so the byte is printed as a number, not as a character | ||||||
|  |     std::cout << static_cast<int32_t>(lane) << " "; | ||||||
|  |   } | ||||||
|  |   std::cout << std::endl; | ||||||
|  | } | ||||||
							
								
								
									
										530
									
								
								csrc/cpu/sgl-kernels/gemm_fp8.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										530
									
								
								csrc/cpu/sgl-kernels/gemm_fp8.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,530 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #include "common.h" | ||||||
|  | #include "vec.h" | ||||||
|  | #include "gemm.h" | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | // we use 4x32 for BLOCK_M | ||||||
|  | #define BLOCK_SIZE_M_SCALE 4 | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | // Converts size floats from input into scalar_t values in out. | ||||||
|  | // Whole output vectors are built from two float vectors at a time; the | ||||||
|  | // elements that do not fill a vector are converted one by one. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { | ||||||
|  |   using OutVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using F32Vec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kOutLanes = OutVec::size(); | ||||||
|  |   constexpr int kF32Lanes = F32Vec::size(); | ||||||
|  |  | ||||||
|  |   int64_t pos = 0; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (; pos <= size - kOutLanes; pos += kOutLanes) { | ||||||
|  |     const F32Vec lo = F32Vec::loadu(input + pos); | ||||||
|  |     const F32Vec hi = F32Vec::loadu(input + pos + kF32Lanes); | ||||||
|  |     const OutVec packed = convert_from_float_ext<scalar_t>(lo, hi); | ||||||
|  |     packed.store(out + pos); | ||||||
|  |   } | ||||||
|  |   // scalar tail | ||||||
|  |   while (pos < size) { | ||||||
|  |     out[pos] = static_cast<scalar_t>(input[pos]); | ||||||
|  |     ++pos; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Writes out[i] = scalar_t(input[i] + bias[i]) for every i in [0, size). | ||||||
|  | // The sum is formed in float; whole output vectors are built from two float | ||||||
|  | // vectors at a time and the remainder is handled element by element. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { | ||||||
|  |   using OutVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using F32Vec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kOutLanes = OutVec::size(); | ||||||
|  |   constexpr int kF32Lanes = F32Vec::size(); | ||||||
|  |  | ||||||
|  |   int64_t pos = 0; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (; pos <= size - kOutLanes; pos += kOutLanes) { | ||||||
|  |     const F32Vec lo = F32Vec::loadu(input + pos) + F32Vec::loadu(bias + pos); | ||||||
|  |     const F32Vec hi = F32Vec::loadu(input + pos + kF32Lanes) + F32Vec::loadu(bias + pos + kF32Lanes); | ||||||
|  |     const OutVec packed = convert_from_float_ext<scalar_t>(lo, hi); | ||||||
|  |     packed.store(out + pos); | ||||||
|  |   } | ||||||
|  |   // scalar tail | ||||||
|  |   while (pos < size) { | ||||||
|  |     out[pos] = static_cast<scalar_t>(input[pos] + bias[pos]); | ||||||
|  |     ++pos; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Expands one block of fp8 (e4m3) packed weights into bf16 in Btmp, | ||||||
|  | // multiplying every value by scale on the way. | ||||||
|  | // | ||||||
|  | //   Btmp     : bf16 destination; row k of the K/2 rows starts at k * ldb_tmp * 2 | ||||||
|  | //   packed_B : fp8 source, read 64 bytes per row | ||||||
|  | //   N        : not referenced in the body (the block width is fixed by BLOCK_N) | ||||||
|  | //   K        : logical depth; K >> 1 rows are processed | ||||||
|  | //   ldb      : source row stride, counted in 2-byte (fp8 pair) units | ||||||
|  | //   ldb_tmp  : destination leading dimension | ||||||
|  | //   scale    : single multiplier applied to the whole block | ||||||
|  | // | ||||||
|  | // Without CPU_CAPABILITY_AVX512 the function fails a TORCH_CHECK. | ||||||
|  | inline void unpack_B( | ||||||
|  |     at::BFloat16* __restrict__ Btmp, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_B, | ||||||
|  |     int N, | ||||||
|  |     int K, | ||||||
|  |     int ldb, | ||||||
|  |     int ldb_tmp, | ||||||
|  |     float scale) { | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  |   // [K/2, N, 2] | ||||||
|  |   const int K2 = K >> 1; | ||||||
|  |   const int ldb2 = ldb; // ldb * 2 >> 1; | ||||||
|  |   // viewed as uint16_t so that one element covers a pair of fp8 values | ||||||
|  |   const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B); | ||||||
|  |   const __m512 vd = _mm512_set1_ps(scale); | ||||||
|  |  | ||||||
|  |   // the 64-byte load and the two 32-element stores below assume this width | ||||||
|  |   constexpr int BLOCK_N = block_size_n(); | ||||||
|  |   static_assert(BLOCK_N == 32); | ||||||
|  |  | ||||||
|  |   // prefetch distance | ||||||
|  |   constexpr int PREFETCH_SIZE_K = 64; | ||||||
|  |  | ||||||
|  | #pragma GCC unroll 4 | ||||||
|  |   for (int k = 0; k < K2; ++k) { | ||||||
|  |     // 64 fp8 values of row k | ||||||
|  |     __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); | ||||||
|  |     if constexpr (PREFETCH_SIZE_K > 0) { | ||||||
|  |       _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // split into two 256-bit halves of 32 fp8 values each | ||||||
|  |     __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); | ||||||
|  |     __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); | ||||||
|  |  | ||||||
|  |     __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); | ||||||
|  |     __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); | ||||||
|  |  | ||||||
|  |     // Apply scale | ||||||
|  |     // (done in fp32: widen each bf16 half, multiply, then narrow again) | ||||||
|  |     __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); | ||||||
|  |     __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); | ||||||
|  |     __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); | ||||||
|  |     __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); | ||||||
|  |  | ||||||
|  |     f0_lo = _mm512_mul_ps(f0_lo, vd); | ||||||
|  |     f0_hi = _mm512_mul_ps(f0_hi, vd); | ||||||
|  |     f1_lo = _mm512_mul_ps(f1_lo, vd); | ||||||
|  |     f1_hi = _mm512_mul_ps(f1_hi, vd); | ||||||
|  |  | ||||||
|  |     bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); | ||||||
|  |     bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); | ||||||
|  |  | ||||||
|  |     // two stores of 32 bf16 values each for this row | ||||||
|  |     _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); | ||||||
|  |     _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); | ||||||
|  |   } | ||||||
|  | #else | ||||||
|  |   TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Generic fallback for the register-tiled micro-kernel. No scalar implementation | ||||||
|  | // exists, so any instantiation that is not covered by a specialization fails at | ||||||
|  | // run time with the message below. | ||||||
|  | template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_nn { | ||||||
|  |   static inline void apply( | ||||||
|  |       const scalar_t* __restrict__ A, | ||||||
|  |       const packed_t* __restrict__ B, | ||||||
|  |       scalar_t* __restrict__ C, | ||||||
|  |       const float* __restrict__ bias, | ||||||
|  |       const float* __restrict__ scale, | ||||||
|  |       int K, | ||||||
|  |       int lda, | ||||||
|  |       int ldb, | ||||||
|  |       int ldc, | ||||||
|  |       int64_t block_size_K) { | ||||||
|  |     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | // AVX512-BF16 micro-kernel: C[BLOCK_M, BLOCK_N] = A(bf16) x B(fp8, block-scaled) (+ bias). | ||||||
|  | // | ||||||
|  | // A is read as 32-bit bf16 pairs and B as 16-bit fp8 pairs ([K/2, N, 2] layout), so | ||||||
|  | // the inner loop index `k` counts pairs of K elements, not single elements. For each | ||||||
|  | // BLOCK_K-wide slice of K a partial sum is accumulated in `vsum`, multiplied by that | ||||||
|  | // slice's scale[kb] and folded into `vc`; `vc` starts from bias (or zero). | ||||||
|  | // | ||||||
|  | //   bias         : BLOCK_N floats, only read when has_bias is true | ||||||
|  | //   scale        : one float per K slice (indexed by kb) | ||||||
|  | //   lda/ldb/ldc  : leading dimensions of A (elements), B (16-bit pairs) and C (elements) | ||||||
|  | //   block_size_K : accepted for interface compatibility; not read here, the slice | ||||||
|  | //                  width is the compile-time BLOCK_K | ||||||
|  | template <bool has_bias, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> { | ||||||
|  |   static inline void apply( | ||||||
|  |       const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, | ||||||
|  |       const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { | ||||||
|  |  | ||||||
|  |     constexpr int ROWS = BLOCK_M; | ||||||
|  |     constexpr int COLS = BLOCK_N / 16; | ||||||
|  |     // B loads and C stores both handle two 16-lane column groups at a time | ||||||
|  |     // (vb[col + 1], vc[... + col + 1]), so COLS must be even. | ||||||
|  |     static_assert(COLS % 2 == 0); | ||||||
|  |  | ||||||
|  |     const int KB = div_up(K, BLOCK_K); | ||||||
|  |     // number of element pairs along K; this is the true upper bound for `k` | ||||||
|  |     const int K2 = K >> 1; | ||||||
|  |  | ||||||
|  |     // prefetch distance | ||||||
|  |     constexpr int PREFETCH_SIZE_K = 64; | ||||||
|  |     constexpr int PREFETCH_SIZE_KB = 1; | ||||||
|  |  | ||||||
|  |     __m512bh va; | ||||||
|  |     __m512bh vb[COLS]; | ||||||
|  |     __m512 vc[ROWS * COLS]; | ||||||
|  |     __m512 vsum[ROWS * COLS]; | ||||||
|  |  | ||||||
|  |     // block quant scale | ||||||
|  |     __m512 vscale; | ||||||
|  |  | ||||||
|  |     // initialise the accumulators with bias (same bias row for every output row) or zero | ||||||
|  |     auto loadc = [&](auto i) { | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |       if constexpr (has_bias) { | ||||||
|  |         vc[i] = _mm512_loadu_ps(bias + col * 16); | ||||||
|  |       } else { | ||||||
|  |         vc[i] = _mm512_setzero_ps(); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(loadc); | ||||||
|  |  | ||||||
|  |     const int lda2 = lda >> 1; | ||||||
|  |     const int ldb2 = ldb; // ldb * 2 >> 1; | ||||||
|  |     const float* a_ptr = reinterpret_cast<const float*>(A); | ||||||
|  |     const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B); | ||||||
|  |  | ||||||
|  |     auto compute = [&](auto i, int k) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       // broadcast one bf16 pair of A once per row | ||||||
|  |       if constexpr (col == 0) { | ||||||
|  |         va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); | ||||||
|  |         if constexpr (PREFETCH_SIZE_K > 0) { | ||||||
|  |           _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       // load and convert B once per k (while processing row 0), two column groups per load | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         if constexpr (col % 2 == 0) { | ||||||
|  |           __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16); | ||||||
|  |           if constexpr (PREFETCH_SIZE_K > 0) { | ||||||
|  |             _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); | ||||||
|  |           } | ||||||
|  |           vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); | ||||||
|  |           vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     constexpr int BLOCK_K2 = BLOCK_K >> 1; | ||||||
|  |     for (int kb = 0; kb < KB; ++kb) { | ||||||
|  |       int kb_start = kb * BLOCK_K2; | ||||||
|  |       // clamp in pair units: clamping to K (elements) would let a trailing partial | ||||||
|  |       // slice run past the end of A and B | ||||||
|  |       int kb_end = std::min(K2, kb_start + BLOCK_K2); | ||||||
|  |       // 1. load scale vector | ||||||
|  |       vscale = _mm512_set1_ps(scale[kb]); | ||||||
|  |       if constexpr (PREFETCH_SIZE_KB > 0) { | ||||||
|  |         _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0); | ||||||
|  |       } | ||||||
|  |       // 2. zero vsum for each block | ||||||
|  |       Unroll<ROWS * COLS>{}([&](auto i) { | ||||||
|  |         vsum[i] = _mm512_setzero_ps(); | ||||||
|  |       }); | ||||||
|  |       // 3. accumulate across each block | ||||||
|  |       for (int k = kb_start; k < kb_end; ++k) { | ||||||
|  |         Unroll<ROWS * COLS>{}(compute, k); | ||||||
|  |       } | ||||||
|  |       // 4. apply scale | ||||||
|  |       Unroll<ROWS * COLS>{}([&](auto i) { | ||||||
|  |         vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); | ||||||
|  |       }); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto storec = [&](auto i) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |       // for COLS = 2,4 use 512bit store | ||||||
|  |       // (two fp32 accumulators are narrowed into one vector of 32 bf16 values) | ||||||
|  |       if constexpr (col % 2 == 0) { | ||||||
|  |         _mm512_storeu_si512( | ||||||
|  |             reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||||
|  |             (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(storec); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // Invokes the MB_SIZE x NB_SIZE micro-kernel on the current tile. Expands in a scope | ||||||
|  | // that provides scalar_t, has_bias, A, B, C, bias, scale, K, lda, ldb, ldc, | ||||||
|  | // block_size_K, mb_start and nb_start. B advances by nb_start * 2 because each | ||||||
|  | // column holds a pair of K values; bias is only offset when has_bias is true. | ||||||
|  | #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                          \ | ||||||
|  |     tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply(         \ | ||||||
|  |         A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ | ||||||
|  |         has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K); | ||||||
|  |  | ||||||
|  | // Generic fallback for the brgemm-based path. There is no implementation for | ||||||
|  | // arbitrary type combinations; reaching this template is reported as an error. | ||||||
|  | template <typename scalar_t, typename packed_t, bool has_bias> | ||||||
|  | struct brgemm { | ||||||
|  |   static inline void apply( | ||||||
|  |       const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, | ||||||
|  |       scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, | ||||||
|  |       const float* __restrict__ bias, const float* __restrict__ scale, | ||||||
|  |       int M, int N, int K, | ||||||
|  |       int lda, int ldb, int ldc) { | ||||||
|  |     TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // brgemm path for bf16 activations and fp8 weights. | ||||||
|  | // | ||||||
|  | // 1. Every BLOCK_K-wide slice of B is dequantized into the bf16 scratch `Btmp` | ||||||
|  | //    with that slice's scale (scale is indexed by slice number). | ||||||
|  | // 2. at::native::cpublas::brgemm computes A x Btmp into the fp32 scratch `Ctmp` | ||||||
|  | //    (leading dimension BLOCK_N, add_C = false). | ||||||
|  | // 3. Ctmp is copied row by row into C, adding `bias` when has_bias is true. | ||||||
|  | template <bool has_bias> | ||||||
|  | struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> { | ||||||
|  |   static inline void apply( | ||||||
|  |       const at::BFloat16* __restrict__ A, | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ B, | ||||||
|  |       at::BFloat16* __restrict__ C, | ||||||
|  |       at::BFloat16* __restrict__ Btmp, | ||||||
|  |       float* __restrict__ Ctmp, | ||||||
|  |       const float* __restrict__ bias, | ||||||
|  |       const float* __restrict__ scale, | ||||||
|  |       int M, | ||||||
|  |       int N, | ||||||
|  |       int K, | ||||||
|  |       int lda, | ||||||
|  |       int ldb, | ||||||
|  |       int ldc) { | ||||||
|  |  | ||||||
|  |     constexpr int BLOCK_N = block_size_n(); | ||||||
|  |  | ||||||
|  |     // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] | ||||||
|  |     const int ldb_tmp = BLOCK_N; | ||||||
|  |  | ||||||
|  |     for (int k = 0; k < K; k += BLOCK_K) { | ||||||
|  |       int kb_size = std::min(BLOCK_K, K - k); | ||||||
|  |  | ||||||
|  |       // slice number; derived from BLOCK_K instead of a hard-coded shift so it | ||||||
|  |       // stays correct if BLOCK_K is ever changed from 128 | ||||||
|  |       int idx = k / BLOCK_K; | ||||||
|  |       unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     at::native::cpublas::brgemm( | ||||||
|  |         M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); | ||||||
|  |  | ||||||
|  |     // copy from Ctmp to C | ||||||
|  |     for (int m = 0; m < M; ++m) { | ||||||
|  |       if constexpr (has_bias) { | ||||||
|  |         copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); | ||||||
|  |       } else { | ||||||
|  |         copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // Computes one [M, N] output tile, either through the brgemm path (`brg` true) or | ||||||
|  | // by tiling it into register-blocked micro-kernels of up to 4 rows. | ||||||
|  | // | ||||||
|  | // Only 32-column micro-kernels are dispatched below; any other (mb_size, nb_size) | ||||||
|  | // combination is rejected with TORCH_CHECK. `bias` is only used when has_bias is true. | ||||||
|  | template <typename scalar_t, bool has_bias> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     scalar_t* __restrict__ Btmp, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ scale, | ||||||
|  |     const float* __restrict__ bias, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg, | ||||||
|  |     int64_t block_size_K) { | ||||||
|  |  | ||||||
|  |   if (brg) { | ||||||
|  |     brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply( | ||||||
|  |         A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // pattern: 1-4-16 | ||||||
|  |   constexpr int64_t BLOCK_M = 4; | ||||||
|  |   constexpr int64_t BLOCK_N = 64; | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |   for (int64_t mb = 0; mb < MB; ++mb) { | ||||||
|  |     int64_t mb_start = mb * BLOCK_M; | ||||||
|  |     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||||
|  |     for (int64_t nb = 0; nb < NB; ++nb) { | ||||||
|  |       int64_t nb_start = nb * BLOCK_N; | ||||||
|  |       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||||
|  |  | ||||||
|  |       // key: high nibble = rows in the tile, low nibble = columns / 16 | ||||||
|  |       switch(mb_size << 4 | nb_size >> 4) { | ||||||
|  |         case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; | ||||||
|  |         case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; | ||||||
|  |         case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; | ||||||
|  |         case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; | ||||||
|  |         // report the actual column count, not the literal text "nb_size" | ||||||
|  |         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Parallel driver for the fp8 block-scaled matmul. | ||||||
|  | // | ||||||
|  | // The [M, N] output is split into BLOCK_M x BLOCK_N tiles that are distributed over | ||||||
|  | // threads with at::parallel_for. Each thread uses its own slice of `buffer` | ||||||
|  | // (buffer_size_per_thread elements of scalar_t): the first BLOCK_N * K elements | ||||||
|  | // serve as Btmp and the remainder is reinterpreted as the fp32 Ctmp scratch. | ||||||
|  | // | ||||||
|  | //   scales2      : one row of div_up(K, block_size_K) scales per group of | ||||||
|  | //                  block_size_N / BLOCK_N column tiles | ||||||
|  | //   bias         : optional, may be nullptr | ||||||
|  | //   mat1_strideM : row stride of mat1; out_strideM: row stride of out | ||||||
|  | template <typename scalar_t> | ||||||
|  | void fp8_scaled_mm_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ out, | ||||||
|  |     const scalar_t* __restrict__ mat1, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ mat2, | ||||||
|  |     const float* __restrict__ scales2, | ||||||
|  |     const float* __restrict__ bias, | ||||||
|  |     scalar_t* __restrict__ buffer, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t mat1_strideM, | ||||||
|  |     int64_t out_strideM, | ||||||
|  |     int64_t block_size_N, | ||||||
|  |     int64_t block_size_K, | ||||||
|  |     int64_t buffer_size_per_thread) { | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |  | ||||||
|  |   const int64_t scale_size_K = div_up(K, block_size_K); | ||||||
|  |   const int64_t blocks_n_per_group = block_size_N / BLOCK_N; | ||||||
|  |  | ||||||
|  |   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M); | ||||||
|  |  | ||||||
|  |   // parallel on [MB, NB] | ||||||
|  |   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { | ||||||
|  |     at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |       int64_t mb{0}, nb{0}; | ||||||
|  |       data_index_init(begin, mb, MB, nb, NB); | ||||||
|  |  | ||||||
|  |       // per-thread scratch: [Btmp | Ctmp] | ||||||
|  |       int tid = at::get_thread_num(); | ||||||
|  |       scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; | ||||||
|  |       float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); | ||||||
|  |  | ||||||
|  |       for (int64_t i = begin; i < end; ++i) { | ||||||
|  |         UNUSED(i); | ||||||
|  |         const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; | ||||||
|  |  | ||||||
|  |         int64_t mb_start = mb * BLOCK_M; | ||||||
|  |         int64_t mb_size = std::min(M - mb_start, BLOCK_M); | ||||||
|  |         int64_t nb_start = nb * BLOCK_N; | ||||||
|  |         int64_t nb_size = std::min(N - nb_start, BLOCK_N); | ||||||
|  |  | ||||||
|  |         // never offset a null bias pointer: arithmetic on nullptr is undefined | ||||||
|  |         const float* bias_ptr = has_bias ? bias + nb_start : nullptr; | ||||||
|  |  | ||||||
|  |         tinygemm_kernel<scalar_t, has_bias>( | ||||||
|  |             /*   A            */ mat1 + mb_start * mat1_strideM, | ||||||
|  |             /*   B            */ mat2 + nb_start * K, // nb * BLOCK_N * K | ||||||
|  |             /*   C            */ out + mb_start * out_strideM + nb_start, | ||||||
|  |             /*   Btmp         */ Btmp, | ||||||
|  |             /*   Ctmp         */ Ctmp, | ||||||
|  |             /*   scale        */ scale_ptr, | ||||||
|  |             /*   bias         */ bias_ptr, | ||||||
|  |             /*   M            */ mb_size, | ||||||
|  |             /*   N            */ nb_size, | ||||||
|  |             /*   K            */ K, | ||||||
|  |             /*   lda          */ mat1_strideM, | ||||||
|  |             /*   ldb          */ nb_size, | ||||||
|  |             /*   ldc          */ out_strideM, | ||||||
|  |             /*   brg          */ use_brgemm, | ||||||
|  |             /*   block_size_K */ block_size_K); | ||||||
|  |  | ||||||
|  |         // move to the next index | ||||||
|  |         data_index_step(mb, MB, nb, NB); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       if (use_brgemm) { | ||||||
|  |         at::native::cpublas::brgemm_release(); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // anonymous namespace | ||||||
|  |  | ||||||
|  | // tinygemm interface | ||||||
|  | // | ||||||
|  | // Public entry point without a bias argument; it forwards to the internal | ||||||
|  | // two-parameter template with has_bias = false and a null bias pointer. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const scalar_t* __restrict__ A, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     scalar_t* __restrict__ Btmp, | ||||||
|  |     float* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ scale, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg, | ||||||
|  |     int64_t block_size_K) { | ||||||
|  |   constexpr bool kWithBias = false; | ||||||
|  |   tinygemm_kernel<scalar_t, kWithBias>( | ||||||
|  |       A, B, C, Btmp, Ctmp, scale, /* bias */ nullptr, | ||||||
|  |       M, N, K, lda, ldb, ldc, brg, block_size_K); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Explicitly instantiates the bias-less tinygemm_kernel<TYPE> interface so the | ||||||
|  | // definition above is emitted in this translation unit for each listed TYPE. | ||||||
|  | // The parameter list must stay in sync with the template definition. | ||||||
|  | #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)    \ | ||||||
|  |   template void tinygemm_kernel<TYPE>(         \ | ||||||
|  |       const TYPE* __restrict__ A,              \ | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ B, \ | ||||||
|  |       TYPE* __restrict__ C,                    \ | ||||||
|  |       TYPE* __restrict__ Btmp,                 \ | ||||||
|  |       float* __restrict__ Ctmp,                \ | ||||||
|  |       const float* __restrict__ scale,         \ | ||||||
|  |       int64_t M,                               \ | ||||||
|  |       int64_t N,                               \ | ||||||
|  |       int64_t K,                               \ | ||||||
|  |       int64_t lda,                             \ | ||||||
|  |       int64_t ldb,                             \ | ||||||
|  |       int64_t ldc,                             \ | ||||||
|  |       bool brg,                                \ | ||||||
|  |       int64_t block_size_K) | ||||||
|  |  | ||||||
|  | // instantiated for both reduced-precision activation types | ||||||
|  | INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); | ||||||
|  |  | ||||||
|  | // out[M, N] = mat1[M, K] x mat2[N, K]^T with block-wise fp8 weight scales (+ bias). | ||||||
|  | // | ||||||
|  | //   mat1       : 2-D bf16/half activations, last dim contiguous | ||||||
|  | //   mat2       : 2-D fp8_e4m3 weights; packed here unless `is_vnni` says they already are | ||||||
|  | //   scales2    : float32 scales of shape [div_up(N, block_size_N), div_up(K, block_size_K)] | ||||||
|  | //   block_size : {block_size_N, block_size_K}; block_size_N must be a multiple of | ||||||
|  | //                BLOCK_N and block_size_K must equal BLOCK_K | ||||||
|  | //   bias       : optional tensor with N entries, read as float | ||||||
|  | //   out_dtype  : must equal mat1's dtype | ||||||
|  | // | ||||||
|  | // Returns a freshly allocated [M, N] tensor of out_dtype. All argument checks run | ||||||
|  | // before any tensor size is read and before the weights are packed. | ||||||
|  | at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, | ||||||
|  |     std::vector<int64_t> block_size, std::optional<at::Tensor>& bias, | ||||||
|  |     at::ScalarType out_dtype, bool is_vnni) { | ||||||
|  |   RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias})); | ||||||
|  |  | ||||||
|  |   // layout and rank first, so that size(1) below is always valid | ||||||
|  |   CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); | ||||||
|  |   CHECK_INPUT(mat2); | ||||||
|  |   CHECK_INPUT(scales2); | ||||||
|  |   CHECK_DIM(2, mat1); | ||||||
|  |   CHECK_DIM(2, mat2); | ||||||
|  |  | ||||||
|  |   // dtypes | ||||||
|  |   const auto st = mat1.scalar_type(); | ||||||
|  |   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, | ||||||
|  |       "fp8_scaled_mm_cpu: expect A to be bfloat16 or half."); | ||||||
|  |   TORCH_CHECK(st == out_dtype, | ||||||
|  |       "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype."); | ||||||
|  |   TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, | ||||||
|  |       "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3."); | ||||||
|  |   TORCH_CHECK(scales2.scalar_type() == at::kFloat, | ||||||
|  |       "fp8_scaled_mm_cpu: expect scales2 to be float32."); | ||||||
|  |  | ||||||
|  |   int64_t M = mat1.size(0); | ||||||
|  |   int64_t N = mat2.size(0); | ||||||
|  |   int64_t K = mat2.size(1); | ||||||
|  |  | ||||||
|  |   CHECK_EQ(mat1.size(1), K); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(block_size.size() == 2, | ||||||
|  |       "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); | ||||||
|  |  | ||||||
|  |   int64_t block_size_N = block_size[0]; | ||||||
|  |   int64_t block_size_K = block_size[1]; | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |   TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); | ||||||
|  |   TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); | ||||||
|  |   CHECK_EQ(scales2.size(0), div_up(N, block_size_N)); | ||||||
|  |   CHECK_EQ(scales2.size(1), div_up(K, block_size_K)); | ||||||
|  |  | ||||||
|  |   // pack only after the weights have passed validation | ||||||
|  |   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||||
|  |  | ||||||
|  |   auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); | ||||||
|  |  | ||||||
|  |   // strides | ||||||
|  |   int64_t mat1_strideM = mat1.stride(0); | ||||||
|  |   int64_t out_strideM = out.stride(0); | ||||||
|  |  | ||||||
|  |   const bool has_bias = bias.has_value(); | ||||||
|  |   const float* bias_data = nullptr; | ||||||
|  |   if (has_bias) { | ||||||
|  |     CHECK_EQ(bias.value().size(0), N); | ||||||
|  |     bias_data = bias.value().data_ptr<float>(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // Btmp : [T, BLOCK_N * K] | ||||||
|  |   // Ctmp : [T, BLOCK_M * BLOCK_N] | ||||||
|  |   // (Ctmp is fp32, hence twice as many 16-bit buffer elements) | ||||||
|  |   int num_threads = at::get_num_threads(); | ||||||
|  |   int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; | ||||||
|  |   auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); | ||||||
|  |  | ||||||
|  |   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { | ||||||
|  |     fp8_scaled_mm_kernel_impl<scalar_t>( | ||||||
|  |         out.data_ptr<scalar_t>(), | ||||||
|  |         mat1.data_ptr<scalar_t>(), | ||||||
|  |         packed_w.data_ptr<at::Float8_e4m3fn>(), | ||||||
|  |         scales2.data_ptr<float>(), | ||||||
|  |         bias_data, | ||||||
|  |         buffer.data_ptr<scalar_t>(), | ||||||
|  |         M, | ||||||
|  |         N, | ||||||
|  |         K, | ||||||
|  |         mat1_strideM, | ||||||
|  |         out_strideM, | ||||||
|  |         block_size_N, | ||||||
|  |         block_size_K, | ||||||
|  |         size_per_thread); | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   return out; | ||||||
|  | } | ||||||
							
								
								
									
										440
									
								
								csrc/cpu/sgl-kernels/gemm_int8.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										440
									
								
								csrc/cpu/sgl-kernels/gemm_int8.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,440 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #include "common.h" | ||||||
|  | #include "vec.h" | ||||||
|  | #include "gemm.h" | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | // Generic fallback for the u8s8 micro-kernel. There is no scalar implementation; | ||||||
|  | // output types without a specialization are rejected at run time. | ||||||
|  | template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_nn { | ||||||
|  |   static inline void apply( | ||||||
|  |       const uint8_t* __restrict__ A, | ||||||
|  |       const int8_t* __restrict__ B, | ||||||
|  |       scalar_t* __restrict__ C, | ||||||
|  |       const float* __restrict__ As, | ||||||
|  |       const float* __restrict__ Bs, | ||||||
|  |       const int32_t* __restrict__ Bcomp, | ||||||
|  |       const float* __restrict__ bias, | ||||||
|  |       int64_t K, | ||||||
|  |       int64_t lda, | ||||||
|  |       int64_t ldb, | ||||||
|  |       int64_t ldc) { | ||||||
|  |     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | // AVX512-VNNI micro-kernel producing a BLOCK_M x BLOCK_N bf16 tile from u8 A and s8 B. | ||||||
|  | // | ||||||
|  | // Integer dot products are accumulated in int32, then the pre-computed compensation | ||||||
|  | // (Bcomp) is subtracted and the result is scaled by the per-row scale As[row] and | ||||||
|  | // the per-column scale Bs[col] (plus bias when has_bias is true) before being | ||||||
|  | // narrowed to bf16. A and B are read four bytes at a time along K. | ||||||
|  | template <bool has_bias, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> { | ||||||
|  |   static inline void apply( | ||||||
|  |       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C, | ||||||
|  |       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||||
|  |       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |  | ||||||
|  |     constexpr int ROWS = BLOCK_M; | ||||||
|  |     constexpr int COLS = BLOCK_N / 16; | ||||||
|  |     // the epilogue handles column groups in pairs (col, col + 1) | ||||||
|  |     static_assert(COLS % 2 == 0); | ||||||
|  |  | ||||||
|  |     // prefetch distance | ||||||
|  |     // (0 disables the prefetch branch below at compile time) | ||||||
|  |     constexpr int PREFETCH_SIZE_K = 0; | ||||||
|  |  | ||||||
|  |     __m512i va; | ||||||
|  |     __m512i vb[COLS]; | ||||||
|  |     __m512i vc[ROWS * COLS]; | ||||||
|  |     __m512i vcomp[COLS]; | ||||||
|  |     __m512  vd0; | ||||||
|  |     __m512  vd1[COLS]; | ||||||
|  |  | ||||||
|  |     // oops! 4x4 spills but luckily we use 4x2 | ||||||
|  |     __m512 vbias[COLS]; | ||||||
|  |  | ||||||
|  |     // [NOTE]: s8s8 igemm compensation in avx512-vnni | ||||||
|  |     // | ||||||
|  |     // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate: | ||||||
|  |     // | ||||||
|  |     //   a * b = (a + 128) * b - 128 * b | ||||||
|  |     //   s   s       u       s    u    s | ||||||
|  |     // | ||||||
|  |     // 1) 128 * b is pre-computed when packing B to vnni formats | ||||||
|  |     // 2) a + 128 is fused when dynamically quantize A | ||||||
|  |     // | ||||||
|  |     auto loadc = [&](auto i) { | ||||||
|  |       vc[i] = _mm512_set1_epi32(0); | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(loadc); | ||||||
|  |  | ||||||
|  |     // K is walked in groups of four bytes, addressed through int32 pointers | ||||||
|  |     const int64_t K4 = K >> 2; | ||||||
|  |     const int64_t lda4 = lda >> 2; | ||||||
|  |     const int64_t ldb4 = ldb; // ldb * 4 >> 2; | ||||||
|  |     const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A); | ||||||
|  |     const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B); | ||||||
|  |  | ||||||
|  |     auto compute = [&](auto i, int64_t k) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       // broadcast four bytes of A once per row | ||||||
|  |       if constexpr (col == 0) { | ||||||
|  |         va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); | ||||||
|  |       } | ||||||
|  |       // load B once per k, while processing row 0 | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); | ||||||
|  |         if constexpr (PREFETCH_SIZE_K > 0) { | ||||||
|  |           _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); | ||||||
|  |     }; | ||||||
|  |     for (int64_t k = 0; k < K4; ++k) { | ||||||
|  |       Unroll<ROWS * COLS>{}(compute, k); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     auto storec = [&](auto i) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       // load a scale | ||||||
|  |       if constexpr(col == 0) { | ||||||
|  |         vd0 = _mm512_set1_ps(As[row]); | ||||||
|  |       } | ||||||
|  |       // load b scale and vcomp per 2 vectors | ||||||
|  |       // also load bias if any | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         if constexpr (col % 2 == 0) { | ||||||
|  |           vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); | ||||||
|  |           vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); | ||||||
|  |           vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); | ||||||
|  |           vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); | ||||||
|  |           if constexpr (has_bias) { | ||||||
|  |             vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); | ||||||
|  |             vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // for COLS = 2, 4 use 512bit store | ||||||
|  |       if constexpr (col % 2 == 0) { | ||||||
|  |         // remove the compensation term, then convert int32 -> fp32 | ||||||
|  |         __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); | ||||||
|  |         __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); | ||||||
|  |         // result = acc * As[row] * Bs[col] (+ bias[col]) | ||||||
|  |         if constexpr (has_bias) { | ||||||
|  |           vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); | ||||||
|  |           vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); | ||||||
|  |         } else { | ||||||
|  |           vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); | ||||||
|  |           vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         // narrow two fp32 vectors into one vector of 32 bf16 values | ||||||
|  |         _mm512_storeu_si512( | ||||||
|  |             reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||||
|  |             (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(storec); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // Invokes the MB_SIZE x NB_SIZE u8s8 micro-kernel on the current tile. Expands in a | ||||||
|  | // scope that provides scalar_t, has_bias, A, B, C, As, Bs, Bcomp, bias, K, lda, ldb, | ||||||
|  | // ldc, mb_start and nb_start. B advances by nb_start * 4 because each column holds | ||||||
|  | // four K bytes; bias is only offset when has_bias is true. | ||||||
|  | #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                          \ | ||||||
|  |     tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply(         \ | ||||||
|  |         A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ | ||||||
|  |         As + mb_start, Bs + nb_start, Bcomp + nb_start,                      \ | ||||||
|  |         has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); | ||||||
|  |  | ||||||
|  | // Computes one [M, N] output tile of the int8 matmul by splitting it into | ||||||
|  | // register-blocked micro-kernels of up to 4 rows and 32 or 64 columns. | ||||||
|  | // | ||||||
|  | // The compensation values are read from directly behind the packed weights, at | ||||||
|  | // byte offset block_size_n() * K from B. `Ctmp` and `brg` are not used in this | ||||||
|  | // implementation. Unsupported tile shapes are rejected with TORCH_CHECK. | ||||||
|  | template <typename scalar_t, bool has_bias> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const uint8_t* __restrict__ A, | ||||||
|  |     const int8_t* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     int32_t* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ As, | ||||||
|  |     const float* __restrict__ Bs, | ||||||
|  |     const float* __restrict__ bias, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg) { | ||||||
|  |  | ||||||
|  |   // B compensation | ||||||
|  |   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K); | ||||||
|  |  | ||||||
|  |   // pattern: 1-4-16 | ||||||
|  |   constexpr int64_t BLOCK_M = 4; | ||||||
|  |   constexpr int64_t BLOCK_N = 64; | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |   for (int64_t mb = 0; mb < MB; ++mb) { | ||||||
|  |     int64_t mb_start = mb * BLOCK_M; | ||||||
|  |     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||||
|  |     for (int64_t nb = 0; nb < NB; ++nb) { | ||||||
|  |       int64_t nb_start = nb * BLOCK_N; | ||||||
|  |       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||||
|  |  | ||||||
|  |       // key: high nibble = rows in the tile, low nibble = columns / 16 | ||||||
|  |       switch(mb_size << 4 | nb_size >> 4) { | ||||||
|  |         // mb_size = 1 | ||||||
|  |         case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; | ||||||
|  |         case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; | ||||||
|  |         // mb_size = 2 | ||||||
|  |         case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; | ||||||
|  |         case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; | ||||||
|  |         // mb_size = 3 | ||||||
|  |         case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; | ||||||
|  |         case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; | ||||||
|  |         // mb_size = 4 | ||||||
|  |         case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; | ||||||
|  |         case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; | ||||||
|  |         // report the actual column count, not the literal text "nb_size" | ||||||
|  |         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Parallel driver for the int8 scaled matmul. | ||||||
|  | // | ||||||
|  | // The [M, N] output is split into BLOCK_M x BLOCK_N tiles distributed over threads | ||||||
|  | // with at::parallel_for. mat1 and out are addressed as dense row-major ([M, K] and | ||||||
|  | // [M, N]); each packed weight row block is get_row_size<int8_t>(K) bytes per column. | ||||||
|  | // | ||||||
|  | //   scales1 : per-row scales for mat1; scales2 : per-column scales for mat2 | ||||||
|  | //   bias    : optional, may be nullptr | ||||||
|  | template<typename scalar_t> | ||||||
|  | void int8_scaled_mm_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ out, | ||||||
|  |     const uint8_t* __restrict__ mat1, | ||||||
|  |     const int8_t* __restrict__ mat2, | ||||||
|  |     const float* __restrict__ scales1, | ||||||
|  |     const float* __restrict__ scales2, | ||||||
|  |     const float* __restrict__ bias, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K) { | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m(); | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |  | ||||||
|  |   // TODO: brgemm u8s8 depends on PyTorch 2.7 release. | ||||||
|  |   const bool use_brgemm = false; | ||||||
|  |  | ||||||
|  |   // K + 4 after compensation | ||||||
|  |   const int64_t packed_row_size = get_row_size<int8_t>(K); | ||||||
|  |  | ||||||
|  |   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { | ||||||
|  |     at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |       int64_t mb{0}, nb{0}; | ||||||
|  |       data_index_init(begin, mb, MB, nb, NB); | ||||||
|  |  | ||||||
|  |       // for brgemm, use int32_t for accumulate | ||||||
|  |       alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; | ||||||
|  |  | ||||||
|  |       // 64-bit indices throughout: begin/end and all offsets are int64_t, so | ||||||
|  |       // narrowing them to int could truncate for large problems | ||||||
|  |       for (int64_t i = begin; i < end; ++i) { | ||||||
|  |         UNUSED(i); | ||||||
|  |         int64_t mb_start = mb * BLOCK_M; | ||||||
|  |         int64_t mb_size = std::min(M - mb_start, BLOCK_M); | ||||||
|  |         int64_t nb_start = nb * BLOCK_N; | ||||||
|  |         int64_t nb_size = std::min(N - nb_start, BLOCK_N); | ||||||
|  |  | ||||||
|  |         // never offset a null bias pointer: arithmetic on nullptr is undefined | ||||||
|  |         const float* bias_ptr = has_bias ? bias + nb_start : nullptr; | ||||||
|  |  | ||||||
|  |         tinygemm_kernel<scalar_t, has_bias>( | ||||||
|  |             /*   A */ mat1 + mb_start * K, | ||||||
|  |             /*   B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, | ||||||
|  |             /*   C */ out + mb_start * N + nb_start, | ||||||
|  |             /* Ctmp*/ Ctmp, | ||||||
|  |             /*  As */ scales1 + mb_start, | ||||||
|  |             /*  Bs */ scales2 + nb_start, | ||||||
|  |             /* bias*/ bias_ptr, | ||||||
|  |             /*   M */ mb_size, | ||||||
|  |             /*   N */ nb_size, | ||||||
|  |             /*   K */ K, | ||||||
|  |             /* lda */ K, | ||||||
|  |             /* ldb */ nb_size, | ||||||
|  |             /* ldc */ N, | ||||||
|  |             /* brg */ use_brgemm); | ||||||
|  |  | ||||||
|  |         // move to the next index | ||||||
|  |         data_index_step(mb, MB, nb, NB); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       if (use_brgemm) { | ||||||
|  |         at::native::cpublas::brgemm_release(); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // anonymous namespace | ||||||
|  |  | ||||||
|  | // tinygemm interface: public, bias-less entry point. | ||||||
|  | // Forwards every argument to the `<scalar_t, /*has_bias=*/false>` kernel and | ||||||
|  | // passes a null bias pointer. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const uint8_t* __restrict__ A, | ||||||
|  |     const int8_t* __restrict__ B, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     int32_t* __restrict__ Ctmp, | ||||||
|  |     const float* __restrict__ As, | ||||||
|  |     const float* __restrict__ Bs, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc, | ||||||
|  |     bool brg) { | ||||||
|  |   constexpr bool kHasBias = false; | ||||||
|  |   tinygemm_kernel<scalar_t, kHasBias>( | ||||||
|  |       A, B, C, Ctmp, As, Bs, /*bias=*/nullptr, M, N, K, lda, ldb, ldc, brg); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Explicit instantiation of the bias-less `tinygemm_kernel` interface for one | ||||||
|  | // output element type. It is instantiated below for at::BFloat16 and at::Half. | ||||||
|  | #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                                     \ | ||||||
|  |     template void tinygemm_kernel<TYPE>(                                                        \ | ||||||
|  |         const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C,      \ | ||||||
|  |         int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \ | ||||||
|  |         int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) | ||||||
|  |  | ||||||
|  | INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); | ||||||
|  |  | ||||||
|  | // Quantizes every row of a 2-D bfloat16/half tensor `A` with `quantize_row_int8`. | ||||||
|  | // | ||||||
|  | // Returns {Aq, As}: | ||||||
|  | //   Aq : [M, K] uint8 (at::kByte) quantized rows, densely packed | ||||||
|  | //   As : [M]    float32 values, one per row, written by `quantize_row_int8` | ||||||
|  | // | ||||||
|  | // `A` only needs a contiguous last dimension; its row stride is honoured. | ||||||
|  | std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) { | ||||||
|  |   RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A})); | ||||||
|  |  | ||||||
|  |   CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); | ||||||
|  |   CHECK_DIM(2, A); | ||||||
|  |  | ||||||
|  |   const int64_t num_rows = A.size(0); | ||||||
|  |   const int64_t num_cols = A.size(1); | ||||||
|  |   const int64_t row_stride = A.stride(0); | ||||||
|  |  | ||||||
|  |   const auto st = A.scalar_type(); | ||||||
|  |   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, | ||||||
|  |       "per_token_quant_int8: expect A to be bfloat16 or half."); | ||||||
|  |  | ||||||
|  |   auto quantized = at::empty({num_rows, num_cols}, A.options().dtype(at::kByte)); | ||||||
|  |   auto row_scales = at::empty({num_rows}, A.options().dtype(at::kFloat)); | ||||||
|  |  | ||||||
|  |   AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { | ||||||
|  |     const scalar_t* __restrict__ src = A.data_ptr<scalar_t>(); | ||||||
|  |     uint8_t* __restrict__ dst = quantized.data_ptr<uint8_t>(); | ||||||
|  |     float* __restrict__ dst_scales = row_scales.data_ptr<float>(); | ||||||
|  |  | ||||||
|  |     // rows are independent, so they are simply split across threads | ||||||
|  |     at::parallel_for(0, num_rows, 0, [&] (int64_t row_begin, int64_t row_end) { | ||||||
|  |       int64_t row = row_begin; | ||||||
|  |       while (row < row_end) { | ||||||
|  |         quantize_row_int8<scalar_t>( | ||||||
|  |             dst + row * num_cols, | ||||||
|  |             dst_scales[row], | ||||||
|  |             src + row * row_stride, | ||||||
|  |             num_cols); | ||||||
|  |         ++row; | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |   }); | ||||||
|  |   return std::make_tuple(quantized, row_scales); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // weight     :  static, per-channel, symmetric | ||||||
|  | // activation : dynamic,   per-token, symmetric | ||||||
|  | // | ||||||
|  | // mat1    : [M, K] | ||||||
|  | // mat2    : [N, K] | ||||||
|  | // scales1 : [M] | ||||||
|  | // scales2 : [N] | ||||||
|  | // bias    : [N] | ||||||
|  | // out     : [M, N] | ||||||
|  | // | ||||||
|  | // mat1 must be uint8, mat2 int8 and both scales float32. When `is_vnni` is true | ||||||
|  | // mat2 is taken as already packed and must have K + sizeof(int32_t) columns; | ||||||
|  | // otherwise it is packed here with `convert_weight_packed`. | ||||||
|  | // Returns a new [M, N] tensor of dtype `out_dtype`. | ||||||
|  | // | ||||||
|  | at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, | ||||||
|  |     at::Tensor& scales1, at::Tensor& scales2, | ||||||
|  |     std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) { | ||||||
|  |   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias})); | ||||||
|  |  | ||||||
|  |   CHECK_INPUT(mat1); | ||||||
|  |   CHECK_INPUT(mat2); | ||||||
|  |   CHECK_INPUT(scales1); | ||||||
|  |   CHECK_INPUT(scales2); | ||||||
|  |   CHECK_DIM(2, mat1); | ||||||
|  |   CHECK_DIM(2, mat2); | ||||||
|  |  | ||||||
|  |   int64_t M = mat1.size(0); | ||||||
|  |   int64_t N = mat2.size(0); | ||||||
|  |   int64_t K = mat1.size(1); | ||||||
|  |  | ||||||
|  |   // see [NOTE]: s8s8 igemm compensation in avx512-vnni | ||||||
|  |   CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); | ||||||
|  |   CHECK_EQ(scales1.numel(), M); | ||||||
|  |   CHECK_EQ(scales2.numel(), N); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); | ||||||
|  |   TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); | ||||||
|  |   TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, | ||||||
|  |       "int8_scaled_mm: expect scales to be float32."); | ||||||
|  |  | ||||||
|  |   // pack the weight only after mat2 has passed the layout/shape/dtype checks, | ||||||
|  |   // so that malformed input is rejected before any packing work is done | ||||||
|  |   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||||
|  |  | ||||||
|  |   auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); | ||||||
|  |  | ||||||
|  |   const bool has_bias = bias.has_value(); | ||||||
|  |   const float* bias_data = nullptr; | ||||||
|  |   if (has_bias) { | ||||||
|  |     CHECK_EQ(bias.value().size(0), N); | ||||||
|  |     bias_data = bias.value().data_ptr<float>(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { | ||||||
|  |     int8_scaled_mm_kernel_impl<scalar_t>( | ||||||
|  |         out.data_ptr<scalar_t>(), | ||||||
|  |         mat1.data_ptr<uint8_t>(), | ||||||
|  |         packed_w.data_ptr<int8_t>(), | ||||||
|  |         scales1.data_ptr<float>(), | ||||||
|  |         scales2.data_ptr<float>(), | ||||||
|  |         bias_data, | ||||||
|  |         M, | ||||||
|  |         N, | ||||||
|  |         K); | ||||||
|  |   }); | ||||||
|  |   return out; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` | ||||||
|  | // | ||||||
|  | // mat1    : [M, K] bfloat16 or half, last dim contiguous (row stride honoured) | ||||||
|  | // mat2    : [N, K] int8, or [N, K + sizeof(int32_t)] when `is_vnni` (pre-packed) | ||||||
|  | // scales2 : [N] float32 | ||||||
|  | // bias    : optional, [N] | ||||||
|  | // out     : [M, N], dtype `out_dtype`, which must equal mat1's dtype | ||||||
|  | // | ||||||
|  | // mat1 is quantized row by row into a scratch buffer, then multiplied with the | ||||||
|  | // packed weight by `int8_scaled_mm_kernel_impl`. | ||||||
|  | at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, | ||||||
|  |     const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) { | ||||||
|  |   // use this function's own name so profiler traces can tell the fused op | ||||||
|  |   // apart from `int8_scaled_mm_cpu` | ||||||
|  |   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_with_quant", std::vector<c10::IValue>({mat1, mat2, scales2, bias})); | ||||||
|  |  | ||||||
|  |   CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); | ||||||
|  |   CHECK_INPUT(mat2); | ||||||
|  |   CHECK_INPUT(scales2); | ||||||
|  |   CHECK_DIM(2, mat1); | ||||||
|  |   CHECK_DIM(2, mat2); | ||||||
|  |  | ||||||
|  |   int64_t M = mat1.size(0); | ||||||
|  |   int64_t N = mat2.size(0); | ||||||
|  |   int64_t K = mat1.size(1); | ||||||
|  |   int64_t lda = mat1.stride(0); | ||||||
|  |  | ||||||
|  |   // see [NOTE]: s8s8 igemm compensation in avx512-vnni | ||||||
|  |   CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); | ||||||
|  |   CHECK_EQ(scales2.numel(), N); | ||||||
|  |  | ||||||
|  |   const auto st = mat1.scalar_type(); | ||||||
|  |   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, | ||||||
|  |       "int8_scaled_mm_with_quant: expect A to be bfloat16 or half."); | ||||||
|  |   TORCH_CHECK(st == out_dtype, | ||||||
|  |       "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype."); | ||||||
|  |   TORCH_CHECK(mat2.scalar_type() == at::kChar, | ||||||
|  |       "int8_scaled_mm_with_quant: expect mat2 to be int8."); | ||||||
|  |   TORCH_CHECK(scales2.scalar_type() == at::kFloat, | ||||||
|  |       "int8_scaled_mm_with_quant: expect scales to be float32."); | ||||||
|  |  | ||||||
|  |   // pack the weight only after mat2 has passed the layout/shape/dtype checks | ||||||
|  |   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||||
|  |  | ||||||
|  |   // scratch layout: [ Aq : M * K uint8 | padding | As : M float ] | ||||||
|  |   // The quantized region is rounded up to a multiple of alignof(float) so the | ||||||
|  |   // scale array never starts at a misaligned byte offset when M * K is not a | ||||||
|  |   // multiple of 4 (assumes the allocation base itself is float-aligned). | ||||||
|  |   constexpr int64_t kScaleAlign = alignof(float); | ||||||
|  |   const int64_t Aq_bytes = (M * K + kScaleAlign - 1) / kScaleAlign * kScaleAlign; | ||||||
|  |   const int64_t buffer_size = Aq_bytes + M * static_cast<int64_t>(sizeof(float)); | ||||||
|  |   auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte)); | ||||||
|  |   auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); | ||||||
|  |  | ||||||
|  |   const bool has_bias = bias.has_value(); | ||||||
|  |   const float* bias_data = nullptr; | ||||||
|  |   if (has_bias) { | ||||||
|  |     CHECK_EQ(bias.value().size(0), N); | ||||||
|  |     bias_data = bias.value().data_ptr<float>(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] { | ||||||
|  |     uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>(); | ||||||
|  |     float* __restrict__ As_data = reinterpret_cast<float*>(Aq_data + Aq_bytes); | ||||||
|  |     const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>(); | ||||||
|  |  | ||||||
|  |     // step 1: quantize every row of mat1 into the scratch buffer | ||||||
|  |     at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { | ||||||
|  |       for (int64_t m = begin; m < end; ++m) { | ||||||
|  |         quantize_row_int8<scalar_t>( | ||||||
|  |             Aq_data + m * K, | ||||||
|  |             As_data[m], | ||||||
|  |             A_data + m * lda, | ||||||
|  |             K); | ||||||
|  |       } | ||||||
|  |     }); | ||||||
|  |  | ||||||
|  |     // step 2: int8 GEMM on the quantized rows | ||||||
|  |     int8_scaled_mm_kernel_impl<scalar_t>( | ||||||
|  |         out.data_ptr<scalar_t>(), | ||||||
|  |         Aq_data, | ||||||
|  |         packed_w.data_ptr<int8_t>(), | ||||||
|  |         As_data, | ||||||
|  |         scales2.data_ptr<float>(), | ||||||
|  |         bias_data, | ||||||
|  |         M, | ||||||
|  |         N, | ||||||
|  |         K); | ||||||
|  |   }); | ||||||
|  |   return out; | ||||||
|  | } | ||||||
							
								
								
									
										1330
									
								
								csrc/cpu/sgl-kernels/moe.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1330
									
								
								csrc/cpu/sgl-kernels/moe.cpp
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										502
									
								
								csrc/cpu/sgl-kernels/moe_fp8.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										502
									
								
								csrc/cpu/sgl-kernels/moe_fp8.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,502 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #include "common.h" | ||||||
|  | #include "gemm.h" | ||||||
|  | #include "vec.h" | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | // Copies `size` elements from `input` to `out`. | ||||||
|  | // Full vectors are moved with unaligned vector loads/stores; any remainder that | ||||||
|  | // does not fill a whole vector is copied element by element, so neither buffer | ||||||
|  | // is accessed past `size` elements (matches the tail handling in `sum_stub`, | ||||||
|  | // which calls this for topk == 1). | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { | ||||||
|  |   using Vec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   constexpr int64_t kVecSize = Vec::size(); | ||||||
|  |   int64_t d; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||||
|  |     Vec data = Vec::loadu(input + d); | ||||||
|  |     data.store(out + d); | ||||||
|  |   } | ||||||
|  |   // scalar tail for sizes that are not a multiple of the vector width | ||||||
|  |   for (; d < size; ++d) { | ||||||
|  |     out[d] = input[d]; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // out[i] = input[i] * weight for i in [0, size). | ||||||
|  | // The vector body multiplies in float32 and converts back to scalar_t; a scalar | ||||||
|  | // loop finishes whatever does not fill a whole vector. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kVecSize = bVec::size(); | ||||||
|  |   const fVec factor = fVec(weight); | ||||||
|  |  | ||||||
|  |   int64_t pos; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (pos = 0; pos <= size - kVecSize; pos += kVecSize) { | ||||||
|  |     const bVec packed = bVec::loadu(input + pos); | ||||||
|  |     fVec lo, hi; | ||||||
|  |     std::tie(lo, hi) = at::vec::convert_to_float(packed); | ||||||
|  |     const fVec scaled_lo = lo * factor; | ||||||
|  |     const fVec scaled_hi = hi * factor; | ||||||
|  |     bVec result = convert_from_float_ext<scalar_t>(scaled_lo, scaled_hi); | ||||||
|  |     result.store(out + pos); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // scalar tail | ||||||
|  |   while (pos < size) { | ||||||
|  |     out[pos] = static_cast<scalar_t>(input[pos] * weight); | ||||||
|  |     ++pos; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Reduces `topk` consecutive rows of length K into one row: | ||||||
|  | //   out[d] = sum over t of input[t * K + d] | ||||||
|  | // Accumulation is done in float32; the result is converted back to scalar_t. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kVecSize = bVec::size(); | ||||||
|  |  | ||||||
|  |   // a single row needs no accumulation: plain copy | ||||||
|  |   if (topk == 1) { | ||||||
|  |     copy_stub(out, input, K); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // vector body: one output vector per iteration, summed over all rows | ||||||
|  |   int64_t col; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (col = 0; col <= K - kVecSize; col += kVecSize) { | ||||||
|  |     fVec acc_lo = fVec(0.f); | ||||||
|  |     fVec acc_hi = fVec(0.f); | ||||||
|  |     for (int row = 0; row < topk; ++row) { | ||||||
|  |       const bVec packed = bVec::loadu(input + row * K + col); | ||||||
|  |       fVec lo, hi; | ||||||
|  |       std::tie(lo, hi) = at::vec::convert_to_float(packed); | ||||||
|  |  | ||||||
|  |       acc_lo += lo; | ||||||
|  |       acc_hi += hi; | ||||||
|  |     } | ||||||
|  |     bVec reduced = convert_from_float_ext<scalar_t>(acc_lo, acc_hi); | ||||||
|  |     reduced.store(out + col); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // scalar tail for the columns that do not fill a whole vector | ||||||
|  |   for (; col < K; ++col) { | ||||||
|  |     float acc = 0.f; | ||||||
|  |     for (int row = 0; row < topk; ++row) { | ||||||
|  |       acc += static_cast<float>(input[row * K + col]); | ||||||
|  |     } | ||||||
|  |     out[col] = static_cast<scalar_t>(acc); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Element-wise fused multiply-add over `size` elements: | ||||||
|  | //   out[i] = input[i] + input2[i] * scale | ||||||
|  | // The vector body computes in float32; a scalar loop handles the remainder. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void add_mul_stub( | ||||||
|  |     scalar_t* __restrict__ out, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const scalar_t* __restrict__ input2, | ||||||
|  |     float scale, | ||||||
|  |     int64_t size) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kVecSize = bVec::size(); | ||||||
|  |   const fVec scale_vec = fVec(scale); | ||||||
|  |  | ||||||
|  |   int64_t pos; | ||||||
|  | #pragma GCC unroll 4 | ||||||
|  |   for (pos = 0; pos <= size - kVecSize; pos += kVecSize) { | ||||||
|  |     // widen both operands to two float vectors each | ||||||
|  |     const bVec base = bVec::loadu(input + pos); | ||||||
|  |     fVec base_lo, base_hi; | ||||||
|  |     std::tie(base_lo, base_hi) = at::vec::convert_to_float(base); | ||||||
|  |  | ||||||
|  |     const bVec addend = bVec::loadu(input2 + pos); | ||||||
|  |     fVec add_lo, add_hi; | ||||||
|  |     std::tie(add_lo, add_hi) = at::vec::convert_to_float(addend); | ||||||
|  |  | ||||||
|  |     base_lo = base_lo + add_lo * scale_vec; | ||||||
|  |     base_hi = base_hi + add_hi * scale_vec; | ||||||
|  |     bVec result = convert_from_float_ext<scalar_t>(base_lo, base_hi); | ||||||
|  |     result.store(out + pos); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // scalar tail | ||||||
|  |   while (pos < size) { | ||||||
|  |     out[pos] = static_cast<scalar_t>(input[pos] + float(input2[pos]) * scale); | ||||||
|  |     ++pos; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // out[i] = silu(input[i]) * input2[i], with silu(x) = x / (1 + exp(-x)). | ||||||
|  | // Only whole vectors are processed (there is no scalar tail), so `size` is | ||||||
|  | // expected to be a multiple of the vector width. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void silu_and_mul_stub( | ||||||
|  |     scalar_t* __restrict__ out, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const scalar_t* __restrict__ input2, | ||||||
|  |     int64_t size) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   const fVec ones = fVec(1.f); | ||||||
|  |  | ||||||
|  |   // no remainder | ||||||
|  | #pragma GCC unroll 4 | ||||||
|  |   for (int64_t pos = 0; pos < size; pos += bVec::size()) { | ||||||
|  |     const bVec gate = bVec::loadu(input + pos); | ||||||
|  |     fVec gate_lo, gate_hi; | ||||||
|  |     std::tie(gate_lo, gate_hi) = at::vec::convert_to_float(gate); | ||||||
|  |  | ||||||
|  |     const bVec up = bVec::loadu(input2 + pos); | ||||||
|  |     fVec up_lo, up_hi; | ||||||
|  |     std::tie(up_lo, up_hi) = at::vec::convert_to_float(up); | ||||||
|  |  | ||||||
|  |     // silu on the gate half ... | ||||||
|  |     gate_lo = gate_lo / (ones + gate_lo.neg().exp_u20()); | ||||||
|  |     gate_hi = gate_hi / (ones + gate_hi.neg().exp_u20()); | ||||||
|  |     // ... then scale by the other half | ||||||
|  |     gate_lo = gate_lo * up_lo; | ||||||
|  |     gate_hi = gate_hi * up_hi; | ||||||
|  |  | ||||||
|  |     bVec result = convert_from_float_ext<scalar_t>(gate_lo, gate_hi); | ||||||
|  |     result.store(out + pos); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // anonymous namespace | ||||||
|  |  | ||||||
|  | // Fused mixture-of-experts forward pass with fp8 (e4m3) expert weights. | ||||||
|  | // | ||||||
|  | // Pipeline, as implemented below: | ||||||
|  | //   stage 1   : ic0 = gathered input rows @ w1   (one row block per expert block) | ||||||
|  | //   stage 1.5 : ic1 = silu(ic0[:, :N]) * ic0[:, N:] | ||||||
|  | //   stage 2   : ic2 = (ic1 @ w2) * topk_weights, scattered back by sorted id | ||||||
|  | //   stage 3   : output = sum of ic2 over the topk dimension | ||||||
|  | // | ||||||
|  | // Buffers (all caller-provided): | ||||||
|  | //   output           : written as M rows of K elements | ||||||
|  | //   ic0 / ic1 / ic2  : intermediates, indexed with row strides 2 * N, N and K | ||||||
|  | //   A_tmp            : per-thread scratch of BLOCK_M * K elements (indexed by thread id) | ||||||
|  | //   B_tmp            : per-thread scratch of BLOCK_N * max(K, N) elements | ||||||
|  | //   C_tmp            : per-thread scratch of 2 * BLOCK_M * BLOCK_N floats | ||||||
|  | //   packed_w1/w2     : expert weights with per-expert strides 2 * N * K and K * N | ||||||
|  | //   w1s / w2s        : block-wise weight scales (block_size_N x block_size_K groups) | ||||||
|  | //   sorted_ids       : BLOCK_M ids per row block; id / topk is the source input row | ||||||
|  | //   expert_ids       : expert index for each row block | ||||||
|  | //   offsets          : offsets[mb + 1] - offsets[mb] is the number of valid rows in block mb | ||||||
|  | // | ||||||
|  | // NOTE(review): parameter `E` is not referenced in the body. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void fused_experts_fp8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic0, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ ic2, | ||||||
|  |     scalar_t* __restrict__ A_tmp, | ||||||
|  |     scalar_t* __restrict__ B_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     int64_t block_size_N, | ||||||
|  |     int64_t block_size_K, | ||||||
|  |     const float* __restrict__ topk_weights, | ||||||
|  |     const int32_t* __restrict__ sorted_ids, | ||||||
|  |     const int32_t* __restrict__ expert_ids, | ||||||
|  |     const int32_t* __restrict__ offsets, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t E, | ||||||
|  |     int64_t topk, | ||||||
|  |     int64_t num_tokens_post_pad) { | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m(); | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |  | ||||||
|  |   // stage 1: intermediate_cache0 = hidden_states @ w1 | ||||||
|  |   const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(2 * N, BLOCK_N); | ||||||
|  |   int64_t scale_size_N = div_up(2 * N, block_size_N); | ||||||
|  |   int64_t scale_size_K = div_up(K, block_size_K); | ||||||
|  |   // number of BLOCK_N tiles that share one row of weight scales | ||||||
|  |   int64_t blocks_n_per_group = block_size_N / BLOCK_N; | ||||||
|  |  | ||||||
|  |   const int64_t stride_e = 2 * N * K; | ||||||
|  |   const int64_t stride_n = K; | ||||||
|  |  | ||||||
|  |   // here we only parallel on half of 2N to fuse silu_and_mul with gemm | ||||||
|  |   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     // get local pointers | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; | ||||||
|  |  | ||||||
|  |     // remembers whether any tile handled by this thread took the brgemm path, | ||||||
|  |     // so brgemm_release() is called once at the end only when needed | ||||||
|  |     bool is_brgemm_used = false; | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB; | ||||||
|  |       int64_t nb = i % NB; | ||||||
|  |  | ||||||
|  |       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       // B shape [K, n_size] in vnni format | ||||||
|  |       int32_t expert_id = expert_ids[mb]; | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; | ||||||
|  |       const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; | ||||||
|  |  | ||||||
|  |       // 1.a load A | ||||||
|  |       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||||
|  |       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||||
|  |  | ||||||
|  |       const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size); | ||||||
|  |       is_brgemm_used = is_brgemm_used || use_brgemm; | ||||||
|  |  | ||||||
|  |       // gather the rows of this block into the per-thread buffer; | ||||||
|  |       // sorted id / topk selects the source row of `input` | ||||||
|  |       for (int64_t m = 0; m < m_size; ++m) { | ||||||
|  |         int32_t index = A_ids[m] / topk; | ||||||
|  |         copy_stub(A + m * K, input + index * K, K); | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       const int64_t offset = offsets[mb]; | ||||||
|  |       tinygemm_kernel<scalar_t>( | ||||||
|  |           /*   A            */ A, | ||||||
|  |           /*   B            */ B, | ||||||
|  |           /*   C            */ ic0 + offset * 2 * N + nb * BLOCK_N, | ||||||
|  |           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||||
|  |           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||||
|  |           /*   scale        */ Bs, | ||||||
|  |           /*   M            */ m_size, | ||||||
|  |           /*   N            */ n_size, | ||||||
|  |           /*   K            */ K, | ||||||
|  |           /*   lda          */ K, | ||||||
|  |           /*   ldb          */ n_size, | ||||||
|  |           /*   ldc          */ 2 * N, | ||||||
|  |           /*   brg          */ use_brgemm, | ||||||
|  |           /*   block_size_K */ block_size_K); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (is_brgemm_used) { | ||||||
|  |       at::native::cpublas::brgemm_release(); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) | ||||||
|  |   // each ic0 row holds 2 * N values: the first N go through silu, the last N multiply the result | ||||||
|  |   at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       silu_and_mul_stub( | ||||||
|  |           ic1 + m * N, | ||||||
|  |           ic0 + m * 2 * N, | ||||||
|  |           ic0 + m * 2 * N + N, | ||||||
|  |           N); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||||
|  |   //   w2 : [E, K, N] as [E, OC, IC] | ||||||
|  |   const int64_t OC = K;  // rename K as OC | ||||||
|  |   const int64_t IC = N;  // rename N as IC | ||||||
|  |   const int64_t MB2 = MB; | ||||||
|  |   const int64_t NB2 = div_up(OC, BLOCK_N); | ||||||
|  |   // scale grid dimensions are recomputed for w2's [K, N] shape | ||||||
|  |   scale_size_N = div_up(K, block_size_N); | ||||||
|  |   scale_size_K = div_up(N, block_size_K); | ||||||
|  |   const int64_t stride_e2 = OC * IC; | ||||||
|  |   const int64_t stride_oc = IC; | ||||||
|  |  | ||||||
|  |   // parallel on [MB2, NB2] | ||||||
|  |   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |     // thread-local output tile, addressed below with row stride BLOCK_N. | ||||||
|  |     // NOTE(review): BLOCK_K is not declared in this function; presumably it | ||||||
|  |     // comes from an included header and is >= BLOCK_N - confirm. | ||||||
|  |     alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; | ||||||
|  |  | ||||||
|  |     bool is_brgemm_used = false; | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB2; | ||||||
|  |       int64_t nb = i % NB2; | ||||||
|  |  | ||||||
|  |       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||||
|  |       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size); | ||||||
|  |       is_brgemm_used = is_brgemm_used || use_brgemm; | ||||||
|  |  | ||||||
|  |       // A ptr from ic1 of [M * topk, N] in sorted order | ||||||
|  |       // so as to avoid copy A to tmp buffer again | ||||||
|  |       const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; | ||||||
|  |       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||||
|  |  | ||||||
|  |       // B shape [IC, n_size] in vnni format | ||||||
|  |       int32_t expert_id = expert_ids[mb]; | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; | ||||||
|  |       const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; | ||||||
|  |  | ||||||
|  |       tinygemm_kernel<scalar_t>( | ||||||
|  |           /*   A            */ A, | ||||||
|  |           /*   B            */ B, | ||||||
|  |           /*   C            */ C, | ||||||
|  |           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||||
|  |           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||||
|  |           /*   scale        */ Bs, | ||||||
|  |           /*   M            */ m_size, | ||||||
|  |           /*   N            */ n_size, | ||||||
|  |           /*   K            */ IC, | ||||||
|  |           /*   lda          */ IC, | ||||||
|  |           /*   ldb          */ n_size, | ||||||
|  |           /*   ldc          */ BLOCK_N, | ||||||
|  |           /*   brg          */ use_brgemm, | ||||||
|  |           /*   block_size_K */ block_size_K); | ||||||
|  |  | ||||||
|  |       // 2.b copy from C to ic2 in original order | ||||||
|  |       //   and also mul topk_weights in float32 | ||||||
|  |       for (int64_t m = 0; m < m_size; ++m) { | ||||||
|  |         int32_t index = A_ids[m]; | ||||||
|  |         float weight = topk_weights[index]; | ||||||
|  |         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (is_brgemm_used) { | ||||||
|  |       at::native::cpublas::brgemm_release(); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 3: out = intermediate_cache2.sum(dim=1) | ||||||
|  |   //   from [M, topk, K] to [M, K] | ||||||
|  |   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       sum_stub(output + m * K, ic2 + m * topk * K, topk, K); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Explicit instantiation of `fused_experts_fp8_kernel_impl` for one activation | ||||||
|  | // element type. It is instantiated below for at::BFloat16 and at::Half. | ||||||
|  | #define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE)             \ | ||||||
|  |   template void fused_experts_fp8_kernel_impl<TYPE>(   \ | ||||||
|  |       TYPE* __restrict__ output,                       \ | ||||||
|  |       TYPE* __restrict__ ic0,                          \ | ||||||
|  |       TYPE* __restrict__ ic1,                          \ | ||||||
|  |       TYPE* __restrict__ ic2,                          \ | ||||||
|  |       TYPE* __restrict__ A_tmp,                        \ | ||||||
|  |       TYPE* __restrict__ B_tmp,                        \ | ||||||
|  |       float* __restrict__ C_tmp,                       \ | ||||||
|  |       const TYPE* __restrict__ input,                  \ | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ packed_w1, \ | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ packed_w2, \ | ||||||
|  |       const float* __restrict__ w1s,                   \ | ||||||
|  |       const float* __restrict__ w2s,                   \ | ||||||
|  |       int64_t block_size_N,                            \ | ||||||
|  |       int64_t block_size_K,                            \ | ||||||
|  |       const float* __restrict__ topk_weights,          \ | ||||||
|  |       const int32_t* __restrict__ sorted_ids,          \ | ||||||
|  |       const int32_t* __restrict__ expert_ids,          \ | ||||||
|  |       const int32_t* __restrict__ offsets,             \ | ||||||
|  |       int64_t M,                                       \ | ||||||
|  |       int64_t N,                                       \ | ||||||
|  |       int64_t K,                                       \ | ||||||
|  |       int64_t E,                                       \ | ||||||
|  |       int64_t topk,                                    \ | ||||||
|  |       int64_t num_tokens_post_pad) | ||||||
|  |  | ||||||
|  | INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | void shared_expert_fp8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic0, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ B_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||||
|  |     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     int64_t block_size_N, | ||||||
|  |     int64_t block_size_K, | ||||||
|  |     const scalar_t* __restrict__ fused_experts_out, | ||||||
|  |     float routed_scaling_factor, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K) { | ||||||
|  |  | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m(); | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |  | ||||||
|  |   // stage 1: intermediate_cache0 = hidden_states @ w1 | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(2 * N, BLOCK_N); | ||||||
|  |   int64_t scale_size_K = div_up(K, block_size_K); | ||||||
|  |   int64_t blocks_n_per_group = block_size_N / BLOCK_N; | ||||||
|  |  | ||||||
|  |   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M); | ||||||
|  |  | ||||||
|  |   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB; | ||||||
|  |       int64_t nb = i % NB; | ||||||
|  |       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||||
|  |       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       tinygemm_kernel<scalar_t>( | ||||||
|  |           /*   A            */ input + mb * BLOCK_M * K, | ||||||
|  |           /*   B            */ packed_w1 + nb * BLOCK_N * K, | ||||||
|  |           /*   C            */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, | ||||||
|  |           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||||
|  |           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||||
|  |           /*   scale        */ w1s + (nb / blocks_n_per_group) * scale_size_K, | ||||||
|  |           /*   M            */ m_size, | ||||||
|  |           /*   N            */ n_size, | ||||||
|  |           /*   K            */ K, | ||||||
|  |           /*   lda          */ K, | ||||||
|  |           /*   ldb          */ n_size, | ||||||
|  |           /*   ldc          */ 2 * N, | ||||||
|  |           /*   brg          */ use_brgemm, | ||||||
|  |           /*   block_size_K */ block_size_K); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (use_brgemm) { | ||||||
|  |       at::native::cpublas::brgemm_release(); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) | ||||||
|  |   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       silu_and_mul_stub( | ||||||
|  |           ic1 + m * N, | ||||||
|  |           ic0 + m * 2 * N, | ||||||
|  |           ic0 + m * 2 * N + N, | ||||||
|  |           N); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||||
|  |   //   w2 : [K, N] as [OC, IC] | ||||||
|  |   const int64_t OC = K;  // rename K as OC | ||||||
|  |   const int64_t IC = N;  // rename N as IC | ||||||
|  |   const int64_t MB2 = MB; | ||||||
|  |   const int64_t NB2 = div_up(K, BLOCK_N); | ||||||
|  |   scale_size_K = div_up(N, block_size_K); | ||||||
|  |  | ||||||
|  |   // parallel on [MB2, NB2] | ||||||
|  |   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |     alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB2; | ||||||
|  |       int64_t nb = i % NB2; | ||||||
|  |       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||||
|  |       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       // 2.a gemm: C = A @ B | ||||||
|  |       tinygemm_kernel<scalar_t>( | ||||||
|  |           /*   A            */ ic1 + mb * BLOCK_M * N, | ||||||
|  |           /*   B            */ packed_w2 + nb * BLOCK_N * N, | ||||||
|  |           /*   C            */ C, | ||||||
|  |           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||||
|  |           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||||
|  |           /*   scale        */ w2s + (nb / blocks_n_per_group) * scale_size_K, | ||||||
|  |           /*   M            */ m_size, | ||||||
|  |           /*   N            */ n_size, | ||||||
|  |           /*   K            */ IC, | ||||||
|  |           /*   lda          */ IC, | ||||||
|  |           /*   ldb          */ n_size, | ||||||
|  |           /*   ldc          */ BLOCK_N, | ||||||
|  |           /*   brg          */ use_brgemm, | ||||||
|  |           /*   block_size_K */ block_size_K); | ||||||
|  |  | ||||||
|  |       // 2.b copy from C to output and add fused_experts_out | ||||||
|  |       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; | ||||||
|  |       const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; | ||||||
|  |       for (int64_t m = 0; m < m_size; ++m) { | ||||||
|  |         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   if (use_brgemm) { | ||||||
|  |     at::native::cpublas::brgemm_release(); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Emits an explicit instantiation of shared_expert_fp8_kernel_impl<TYPE>. | ||||||
|  | // TYPE is the element type of output, ic0, ic1, B_tmp, input and | ||||||
|  | // fused_experts_out; the packed weights are at::Float8_e4m3fn and the | ||||||
|  | // scales / C_tmp are float regardless of TYPE. | ||||||
|  | #define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE)   \ | ||||||
|  |   template void shared_expert_fp8_kernel_impl<TYPE>(   \ | ||||||
|  |       TYPE* __restrict__ output,                       \ | ||||||
|  |       TYPE* __restrict__ ic0,                          \ | ||||||
|  |       TYPE* __restrict__ ic1,                          \ | ||||||
|  |       TYPE* __restrict__ B_tmp,                        \ | ||||||
|  |       float* __restrict__ C_tmp,                       \ | ||||||
|  |       const TYPE* __restrict__ input,                  \ | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ packed_w1, \ | ||||||
|  |       const at::Float8_e4m3fn* __restrict__ packed_w2, \ | ||||||
|  |       const float* __restrict__ w1s,                   \ | ||||||
|  |       const float* __restrict__ w2s,                   \ | ||||||
|  |       int64_t block_size_N,                            \ | ||||||
|  |       int64_t block_size_K,                            \ | ||||||
|  |       const TYPE* __restrict__ fused_experts_out,      \ | ||||||
|  |       float routed_scaling_factor,                     \ | ||||||
|  |       int64_t M,                                       \ | ||||||
|  |       int64_t N,                                       \ | ||||||
|  |       int64_t K) | ||||||
|  |  | ||||||
|  | // instantiations provided for at::BFloat16 and at::Half | ||||||
|  | INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); | ||||||
							
								
								
									
										769
									
								
								csrc/cpu/sgl-kernels/moe_int8.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										769
									
								
								csrc/cpu/sgl-kernels/moe_int8.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,769 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #include "common.h" | ||||||
|  | #include "vec.h" | ||||||
|  | #include "gemm.h" | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | // Copies `size` elements from `input` to `out`. | ||||||
|  | // Whole vectors go through at::vec::Vectorized; a scalar tail covers sizes that | ||||||
|  | // are not a multiple of the vector width, so nothing past `size` is read or | ||||||
|  | // written (sum_stub forwards its K here when topk == 1). | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { | ||||||
|  |   using Vec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   constexpr int kVecSize = Vec::size(); | ||||||
|  |   int64_t d; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||||
|  |     Vec data = Vec::loadu(input + d); | ||||||
|  |     data.store(out + d); | ||||||
|  |   } | ||||||
|  |   // remainder, one element at a time | ||||||
|  |   for (; d < size; ++d) { | ||||||
|  |     out[d] = input[d]; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // uint8_t specialization: a plain byte copy, so any length is fine | ||||||
|  | // (size might be 64x + 32). | ||||||
|  | template <> | ||||||
|  | inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { | ||||||
|  |   const auto num_bytes = size * sizeof(uint8_t); | ||||||
|  |   std::memcpy(out, input, num_bytes); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // out[i] = input[i] * weight, with the float product converted to scalar_t. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kStep = bVec::size(); | ||||||
|  |   const fVec weight_splat = fVec(weight); | ||||||
|  |  | ||||||
|  |   // vector body: two float vectors are scaled and packed into one scalar_t vector | ||||||
|  |   int64_t pos = 0; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (; pos <= size - kStep; pos += kStep) { | ||||||
|  |     fVec lo = fVec::loadu(input + pos) * weight_splat; | ||||||
|  |     fVec hi = fVec::loadu(input + pos + fVec::size()) * weight_splat; | ||||||
|  |     bVec packed = convert_from_float_ext<scalar_t>(lo, hi); | ||||||
|  |     packed.store(out + pos); | ||||||
|  |   } | ||||||
|  |   // scalar tail for whatever did not fill a vector | ||||||
|  |   while (pos < size) { | ||||||
|  |     out[pos] = static_cast<scalar_t>(input[pos] * weight); | ||||||
|  |     ++pos; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Reduces `topk` contiguous rows of length K into a single row: | ||||||
|  | //   out[d] = sum over t of input[t * K + d], accumulated in float. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kStep = bVec::size(); | ||||||
|  |  | ||||||
|  |   // a single row needs no accumulation, copy it straight through | ||||||
|  |   if (topk == 1) { | ||||||
|  |     copy_stub(out, input, K); | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // vector body: accumulate each column chunk over all rows | ||||||
|  |   int64_t col = 0; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (; col <= K - kStep; col += kStep) { | ||||||
|  |     fVec acc_lo = fVec(0.f); | ||||||
|  |     fVec acc_hi = fVec(0.f); | ||||||
|  |     for (int row = 0; row < topk; ++row) { | ||||||
|  |       bVec chunk = bVec::loadu(input + row * K + col); | ||||||
|  |       fVec chunk_lo, chunk_hi; | ||||||
|  |       std::tie(chunk_lo, chunk_hi) = at::vec::convert_to_float(chunk); | ||||||
|  |  | ||||||
|  |       acc_lo += chunk_lo; | ||||||
|  |       acc_hi += chunk_hi; | ||||||
|  |     } | ||||||
|  |     bVec packed = convert_from_float_ext<scalar_t>(acc_lo, acc_hi); | ||||||
|  |     packed.store(out + col); | ||||||
|  |   } | ||||||
|  |   // leftover columns, one element at a time | ||||||
|  |   for (; col < K; ++col) { | ||||||
|  |     float acc = 0.f; | ||||||
|  |     for (int row = 0; row < topk; ++row) { | ||||||
|  |       acc += static_cast<float>(input[row * K + col]); | ||||||
|  |     } | ||||||
|  |     out[col] = static_cast<scalar_t>(acc); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // out[i] = input[i] + input2[i] * scale | ||||||
|  | //   `input` is float, `input2` and `out` are scalar_t; the math runs in float. | ||||||
|  | template <typename scalar_t> | ||||||
|  | inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, | ||||||
|  |     const scalar_t* __restrict__ input2, float scale, int64_t size) { | ||||||
|  |  | ||||||
|  |   using bVec = at::vec::Vectorized<scalar_t>; | ||||||
|  |   using fVec = at::vec::Vectorized<float>; | ||||||
|  |   constexpr int kStep = bVec::size(); | ||||||
|  |   const fVec scale_splat = fVec(scale); | ||||||
|  |  | ||||||
|  |   // vector body: one scalar_t vector of input2 pairs with two float vectors of input | ||||||
|  |   int64_t pos = 0; | ||||||
|  |   #pragma GCC unroll 4 | ||||||
|  |   for (; pos <= size - kStep; pos += kStep) { | ||||||
|  |     fVec acc_lo = fVec::loadu(input + pos); | ||||||
|  |     fVec acc_hi = fVec::loadu(input + pos + fVec::size()); | ||||||
|  |  | ||||||
|  |     bVec addend = bVec::loadu(input2 + pos); | ||||||
|  |     fVec addend_lo, addend_hi; | ||||||
|  |     std::tie(addend_lo, addend_hi) = at::vec::convert_to_float(addend); | ||||||
|  |  | ||||||
|  |     acc_lo = acc_lo + addend_lo * scale_splat; | ||||||
|  |     acc_hi = acc_hi + addend_hi * scale_splat; | ||||||
|  |     bVec packed = convert_from_float_ext<scalar_t>(acc_lo, acc_hi); | ||||||
|  |     packed.store(out + pos); | ||||||
|  |   } | ||||||
|  |   // scalar tail | ||||||
|  |   while (pos < size) { | ||||||
|  |     out[pos] = static_cast<scalar_t>(input[pos] + float(input2[pos]) * scale); | ||||||
|  |     ++pos; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// gemm for w13 | ||||||
|  | /// Primary template: selected for every scalar_t without a specialization and | ||||||
|  | /// always fails through TORCH_CHECK, because no scalar implementation exists. | ||||||
|  | template <typename scalar_t, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_vnni { | ||||||
|  |   static inline void apply( | ||||||
|  |       const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C, | ||||||
|  |       const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, | ||||||
|  |       const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, | ||||||
|  |       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |     // name the kernel that was actually selected so the failure is traceable | ||||||
|  |     TORCH_CHECK(false, "tinygemm_kernel_vnni: scalar path not implemented!"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // AVX512 specialization producing at::BFloat16 output. | ||||||
|  | // For one BLOCK_M x BLOCK_N tile it accumulates A @ B0 and A @ B1 in int32 | ||||||
|  | // with _mm512_dpbusd_epi32, dequantizes both results and stores | ||||||
|  | // silu(A @ B0) * (A @ B1) to C as bf16. | ||||||
|  | //   A            : uint8 rows, row stride lda (bytes) | ||||||
|  | //   B0 / B1      : int8 weights read as int32 words, row stride ldb (int32 words) | ||||||
|  | //   As           : one float scale per row of A | ||||||
|  | //   Bs0 / Bs1    : one float scale per output column | ||||||
|  | //   Bcomp0/Bcomp1: one int32 compensation term per output column | ||||||
|  | // NOTE(review): the compensation presumably cancels the offset introduced when | ||||||
|  | // A is quantized to unsigned 8-bit -- confirm against quantize_row_int8. | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | template <int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> { | ||||||
|  |   static inline void apply( | ||||||
|  |       const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C, | ||||||
|  |       const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, | ||||||
|  |       const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, | ||||||
|  |       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |  | ||||||
|  |     // the tile is held as ROWS x COLS vectors of 16 lanes each | ||||||
|  |     constexpr int ROWS = BLOCK_M; | ||||||
|  |     constexpr int COLS = BLOCK_N / 16; | ||||||
|  |     // storec below consumes columns in pairs | ||||||
|  |     static_assert(COLS % 2 == 0); | ||||||
|  |  | ||||||
|  |     __m512i va; | ||||||
|  |     __m512i vb0[COLS]; | ||||||
|  |     __m512i vb1[COLS]; | ||||||
|  |     __m512i vc0[ROWS * COLS]; | ||||||
|  |     __m512i vc1[ROWS * COLS]; | ||||||
|  |     __m512i vcomp0[COLS]; | ||||||
|  |     __m512i vcomp1[COLS]; | ||||||
|  |     __m512  was; | ||||||
|  |     __m512  vbs0[COLS]; | ||||||
|  |     __m512  vbs1[COLS]; | ||||||
|  |  | ||||||
|  |     // zero both accumulator sets | ||||||
|  |     auto loadc = [&](auto i) { | ||||||
|  |       vc0[i] = _mm512_set1_epi32(0); | ||||||
|  |       vc1[i] = _mm512_set1_epi32(0); | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(loadc); | ||||||
|  |  | ||||||
|  |     // A and B are walked as int32 words, i.e. 4 consecutive 8-bit values per step | ||||||
|  |     const int64_t K4 = K >> 2; | ||||||
|  |     const int64_t lda4 = lda >> 2; | ||||||
|  |     const int64_t ldb4 = ldb; // ldb * 4 >> 2; | ||||||
|  |     const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A); | ||||||
|  |     const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0); | ||||||
|  |     const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1); | ||||||
|  |  | ||||||
|  |     // one k-step: va is refreshed at the first column of each row, vb0/vb1 are | ||||||
|  |     // loaded while processing row 0 and reused by the remaining rows | ||||||
|  |     auto compute = [&](auto i, int64_t k) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       if constexpr (col == 0) { | ||||||
|  |         va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); | ||||||
|  |       } | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); | ||||||
|  |         vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); | ||||||
|  |       } | ||||||
|  |       vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); | ||||||
|  |       vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); | ||||||
|  |     }; | ||||||
|  |     for (int64_t k = 0; k < K4; ++k) { | ||||||
|  |       Unroll<ROWS * COLS>{}(compute, k); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // dequantize: (acc - comp[col]) * As[row] * Bs[col]; the float bit patterns | ||||||
|  |     // are written back into vc0/vc1 for the store pass | ||||||
|  |     auto scalec = [&](auto i) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       // load a scale | ||||||
|  |       if constexpr(col == 0) { | ||||||
|  |         was = _mm512_set1_ps(As[row]); | ||||||
|  |       } | ||||||
|  |       // load b scale and vcomp | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); | ||||||
|  |         vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); | ||||||
|  |         vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); | ||||||
|  |         vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); | ||||||
|  |       } | ||||||
|  |       __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); | ||||||
|  |       __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); | ||||||
|  |       vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col])); | ||||||
|  |       vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col])); | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(scalec); | ||||||
|  |  | ||||||
|  |     // store pass: every even column handles itself and its right neighbour, | ||||||
|  |     // so 32 floats become 32 bf16 values written with one 512-bit store | ||||||
|  |     using Vec = at::vec::Vectorized<float>; | ||||||
|  |     const Vec one = Vec(1.f); | ||||||
|  |     auto storec = [&](auto i) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |       // for COLS = 2, 4 use 512bit store | ||||||
|  |       if constexpr (col % 2 == 0) { | ||||||
|  |         Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]); | ||||||
|  |         Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); | ||||||
|  |         Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); | ||||||
|  |         Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); | ||||||
|  |         // silu | ||||||
|  |         x0 = x0 / (one + x0.neg().exp_u20()); | ||||||
|  |         x1 = x1 / (one + x1.neg().exp_u20()); | ||||||
|  |         // mul | ||||||
|  |         x0 = x0 * y0; | ||||||
|  |         x1 = x1 * y1; | ||||||
|  |  | ||||||
|  |         _mm512_storeu_si512( | ||||||
|  |             reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||||
|  |             (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); | ||||||
|  |         } | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(storec); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // Dispatch helper for the w13 tinygemm_kernel driver: invokes | ||||||
|  | // tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply on the current tile. | ||||||
|  | // It expands in place and therefore needs A, B0, B1, C, As, Bs0, Bs1, Bcomp0, | ||||||
|  | // Bcomp1, mb_start, nb_start, K, lda, ldb and ldc in scope at the call site. | ||||||
|  | // NOTE(review): B0/B1 advance by nb_start * 4 bytes, presumably because the | ||||||
|  | // packed layout stores 4 K-values per column entry -- confirm against the packer. | ||||||
|  | #define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE)                        \ | ||||||
|  |     tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply(                 \ | ||||||
|  |         A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4,            \ | ||||||
|  |         C + mb_start * ldc + nb_start, As + mb_start,                        \ | ||||||
|  |         Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ | ||||||
|  |         K, lda, ldb, ldc); | ||||||
|  |  | ||||||
|  | // Blocked driver for the fused w13 GEMM: walks the [M, N] output in | ||||||
|  | // BLOCK_M x BLOCK_N tiles and dispatches each tile to tinygemm_kernel_vnni. | ||||||
|  | //   A         : uint8 activations, row stride lda | ||||||
|  | //   B0 / B1   : packed int8 weights; their int32 compensation terms are read | ||||||
|  | //               from B + block_size_n() * K | ||||||
|  | //   C         : output tile, row stride ldc | ||||||
|  | //   As        : per-row scales of A | ||||||
|  | //   Bs0 / Bs1 : per-column scales of B0 / B1 | ||||||
|  | // Fails through TORCH_CHECK for any tile that is not {1,2,3,4} x 32. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const uint8_t* __restrict__ A, | ||||||
|  |     const int8_t* __restrict__ B0, | ||||||
|  |     const int8_t* __restrict__ B1, | ||||||
|  |     scalar_t* __restrict__ C, | ||||||
|  |     const float* __restrict__ As, | ||||||
|  |     const float* __restrict__ Bs0, | ||||||
|  |     const float* __restrict__ Bs1, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc) { | ||||||
|  |  | ||||||
|  |   // B compensation, stored right behind the packed weights | ||||||
|  |   const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K); | ||||||
|  |   const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K); | ||||||
|  |  | ||||||
|  |   // pattern: 1-(2+2)-(8+8) | ||||||
|  |   constexpr int64_t BLOCK_M = 4; | ||||||
|  |   constexpr int64_t BLOCK_N = 32; | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |   for (int64_t mb = 0; mb < MB; ++mb) { | ||||||
|  |     int64_t mb_start = mb * BLOCK_M; | ||||||
|  |     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||||
|  |     for (int64_t nb = 0; nb < NB; ++nb) { | ||||||
|  |       int64_t nb_start = nb * BLOCK_N; | ||||||
|  |       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||||
|  |  | ||||||
|  |       // key: high nibble = rows in the tile, low nibble = columns / 16 | ||||||
|  |       switch(mb_size << 4 | nb_size >> 4) { | ||||||
|  |         case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break; | ||||||
|  |         case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break; | ||||||
|  |         case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break; | ||||||
|  |         case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break; | ||||||
|  |         // report the actual tile width, not the literal text "nb_size" | ||||||
|  |         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /// gemm for w2 | ||||||
|  | /// Primary template: selected for every scalar_t without a specialization and | ||||||
|  | /// always fails through TORCH_CHECK, because no scalar implementation exists. | ||||||
|  | template <typename scalar_t, int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_vnni2 { | ||||||
|  |   static inline void apply( | ||||||
|  |       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, | ||||||
|  |       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||||
|  |       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |     // name the kernel that was actually selected so the failure is traceable | ||||||
|  |     TORCH_CHECK(false, "tinygemm_kernel_vnni2: scalar path not implemented!"); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // AVX512 specialization selected for scalar_t == at::BFloat16; the result | ||||||
|  | // itself is written to C as float. | ||||||
|  | // For one BLOCK_M x BLOCK_N tile it accumulates A @ B in int32 with | ||||||
|  | // _mm512_dpbusd_epi32 and stores (acc - Bcomp[col]) * As[row] * Bs[col]. | ||||||
|  | //   A     : uint8 rows, row stride lda (bytes) | ||||||
|  | //   B     : int8 weights read as int32 words, row stride ldb (int32 words) | ||||||
|  | //   As    : one float scale per row of A | ||||||
|  | //   Bs    : one float scale per output column | ||||||
|  | //   Bcomp : one int32 compensation term per output column | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | template <int BLOCK_M, int BLOCK_N> | ||||||
|  | struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> { | ||||||
|  |   static inline void apply( | ||||||
|  |       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, | ||||||
|  |       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||||
|  |       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||||
|  |  | ||||||
|  |     // the tile is held as ROWS x COLS vectors of 16 lanes each | ||||||
|  |     constexpr int ROWS = BLOCK_M; | ||||||
|  |     constexpr int COLS = BLOCK_N / 16; | ||||||
|  |     // storec below loads scales and compensation two columns at a time | ||||||
|  |     static_assert(COLS % 2 == 0); | ||||||
|  |  | ||||||
|  |     __m512i va; | ||||||
|  |     __m512i vb[COLS]; | ||||||
|  |     __m512i vc[ROWS * COLS]; | ||||||
|  |     __m512i vcomp[COLS]; | ||||||
|  |     __m512  was; | ||||||
|  |     __m512  vbs[COLS]; | ||||||
|  |  | ||||||
|  |     // zero the accumulators | ||||||
|  |     auto loadc = [&](auto i) { | ||||||
|  |       vc[i] = _mm512_set1_epi32(0); | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(loadc); | ||||||
|  |  | ||||||
|  |     // A and B are walked as int32 words, i.e. 4 consecutive 8-bit values per step | ||||||
|  |     const int64_t K4 = K >> 2; | ||||||
|  |     const int64_t lda4 = lda >> 2; | ||||||
|  |     const int64_t ldb4 = ldb; // ldb * 4 >> 2; | ||||||
|  |     const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A); | ||||||
|  |     const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B); | ||||||
|  |  | ||||||
|  |     // one k-step: va is refreshed at the first column of each row, vb is | ||||||
|  |     // loaded while processing row 0 and reused by the remaining rows | ||||||
|  |     auto compute = [&](auto i, int64_t k) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       if constexpr (col == 0) { | ||||||
|  |         va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); | ||||||
|  |       } | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); | ||||||
|  |       } | ||||||
|  |       vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); | ||||||
|  |     }; | ||||||
|  |     for (int64_t k = 0; k < K4; ++k) { | ||||||
|  |       Unroll<ROWS * COLS>{}(compute, k); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // dequantize each accumulator and write it to C as 16 floats | ||||||
|  |     auto storec = [&](auto i) { | ||||||
|  |       constexpr int row = i / COLS; | ||||||
|  |       constexpr int col = i % COLS; | ||||||
|  |  | ||||||
|  |       // load a scale | ||||||
|  |       if constexpr(col == 0) { | ||||||
|  |         was = _mm512_set1_ps(As[row]); | ||||||
|  |       } | ||||||
|  |       // load b scale and vcomp per 2 vectors | ||||||
|  |       // also load bias if any | ||||||
|  |       if constexpr (row == 0) { | ||||||
|  |         if constexpr (col % 2 == 0) { | ||||||
|  |           vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); | ||||||
|  |           vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); | ||||||
|  |           vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); | ||||||
|  |           vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); | ||||||
|  |       x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]); | ||||||
|  |       _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); | ||||||
|  |     }; | ||||||
|  |     Unroll<ROWS * COLS>{}(storec); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // Dispatch helper for the w2 tinygemm_kernel driver: invokes | ||||||
|  | // tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply on the current tile. | ||||||
|  | // It expands in place and therefore needs A, B, C, As, Bs, Bcomp, mb_start, | ||||||
|  | // nb_start, K, lda, ldb and ldc in scope at the call site. | ||||||
|  | #define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE)                       \ | ||||||
|  |     tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply(                \ | ||||||
|  |         A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ | ||||||
|  |         As + mb_start, Bs + nb_start, Bcomp + nb_start,                      \ | ||||||
|  |         K, lda, ldb, ldc); | ||||||
|  |  | ||||||
|  | // Blocked driver for the w2 GEMM: walks the [M, N] output in tiles and | ||||||
|  | // dispatches each tile to tinygemm_kernel_vnni2, which writes float results. | ||||||
|  | //   A   : uint8 activations, row stride lda | ||||||
|  | //   B   : packed int8 weights; the int32 compensation terms are read from | ||||||
|  | //         B + block_size_n() * K | ||||||
|  | //   C   : float output tile, row stride ldc | ||||||
|  | //   As  : per-row scales of A | ||||||
|  | //   Bs  : per-column scales of B | ||||||
|  | // Fails through TORCH_CHECK for any tile without a matching case below. | ||||||
|  | // NOTE(review): BLOCK_N is 64 but only 32-wide kernels are dispatched; the | ||||||
|  | // caller in this file passes N <= block_size_n(), so this presumably relies | ||||||
|  | // on block_size_n() being 32 -- confirm. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void tinygemm_kernel( | ||||||
|  |     const uint8_t* __restrict__ A, | ||||||
|  |     const int8_t* __restrict__ B, | ||||||
|  |     float* __restrict__ C, | ||||||
|  |     const float* __restrict__ As, | ||||||
|  |     const float* __restrict__ Bs, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t lda, | ||||||
|  |     int64_t ldb, | ||||||
|  |     int64_t ldc) { | ||||||
|  |  | ||||||
|  |   // B compensation | ||||||
|  |   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K); | ||||||
|  |  | ||||||
|  |   // pattern: 1-4-16 | ||||||
|  |   constexpr int64_t BLOCK_M = 4; | ||||||
|  |   constexpr int64_t BLOCK_N = 64; | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |   for (int64_t mb = 0; mb < MB; ++mb) { | ||||||
|  |     int64_t mb_start = mb * BLOCK_M; | ||||||
|  |     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||||
|  |     for (int64_t nb = 0; nb < NB; ++nb) { | ||||||
|  |       int64_t nb_start = nb * BLOCK_N; | ||||||
|  |       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||||
|  |  | ||||||
|  |       // key: high nibble = rows in the tile, low nibble = columns / 16 | ||||||
|  |       switch(mb_size << 4 | nb_size >> 4) { | ||||||
|  |         case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; | ||||||
|  |         case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; | ||||||
|  |         case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; | ||||||
|  |         case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; | ||||||
|  |         // report the actual tile width, not the literal text "nb_size" | ||||||
|  |         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // anonymous namespace | ||||||
|  |  | ||||||
|  | // Fused mixture-of-experts forward pass with int8 weights, run in stages: | ||||||
|  | //   stage 0   : quantize each row of `input` ([M, K]) into Aq_tmp, with one | ||||||
|  | //               scale per row in As_tmp | ||||||
|  | //   stage 1   : ic1 = silu_and_mul(A @ w1), where the rows of A are gathered | ||||||
|  | //               per block through sorted_ids | ||||||
|  | //   stage 1.5 : quantize ic1 ([M * topk, N]) back into Aq_tmp / As_tmp | ||||||
|  | //   stage 2   : ic2 = (ic1 @ w2) * topk_weights, scattered by sorted_ids | ||||||
|  | //   stage 3   : output[m] = sum of the topk rows of ic2 that belong to row m | ||||||
|  | // Per block `mb`, expert_ids[mb] selects the expert and | ||||||
|  | // offsets[mb + 1] - offsets[mb] gives the number of valid rows. | ||||||
|  | // A_tmp and C_tmp are indexed by at::get_thread_num(), i.e. used as per-thread | ||||||
|  | // scratch. N must be a multiple of block_size_n() (TORCH_CHECK below). | ||||||
|  | // NOTE(review): buffer capacities are not validated here; presumably the | ||||||
|  | // caller sizes Aq_tmp / As_tmp for both the [M, K] and [M * topk, N] uses -- | ||||||
|  | // confirm. | ||||||
|  | template <typename scalar_t> | ||||||
|  | void fused_experts_int8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     scalar_t* __restrict__ ic2, | ||||||
|  |     uint8_t* __restrict__ A_tmp, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     uint8_t* __restrict__ Aq_tmp, | ||||||
|  |     float* __restrict__ As_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const int8_t* __restrict__ packed_w1, | ||||||
|  |     const int8_t* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     const float* __restrict__ topk_weights, | ||||||
|  |     const int32_t* __restrict__ sorted_ids, | ||||||
|  |     const int32_t* __restrict__ expert_ids, | ||||||
|  |     const int32_t* __restrict__ offsets, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K, | ||||||
|  |     int64_t E, | ||||||
|  |     int64_t topk, | ||||||
|  |     int64_t num_tokens_post_pad) { | ||||||
|  |  | ||||||
|  |   // handle 2 tiles per block | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m(); | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |  | ||||||
|  |   // stage 0: quantize input to uint8, [M, K] | ||||||
|  |   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       quantize_row_int8<scalar_t>( | ||||||
|  |           Aq_tmp + m * K, | ||||||
|  |           As_tmp[m], | ||||||
|  |           input + m * K, | ||||||
|  |           K); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 1: intermediate_cache1 = silu(hidden_states @ w1) | ||||||
|  |   const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |  | ||||||
|  |   // strides for w1: [E, 2N, K] | ||||||
|  |   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); | ||||||
|  |  | ||||||
|  |   // K and N are packed for int8 | ||||||
|  |   const int64_t packed_K = get_row_size<int8_t>(K); | ||||||
|  |   const int64_t packed_N = get_row_size<int8_t>(N); | ||||||
|  |  | ||||||
|  |   const int64_t stride_e = 2 * N * packed_K; | ||||||
|  |   const int64_t stride_n = packed_K; | ||||||
|  |   // here we only parallel on half of 2N to fuse silu_and_mul with gemm | ||||||
|  |   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     // get local pointers | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |     uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; | ||||||
|  |  | ||||||
|  |     // per-row activation scales for the rows gathered into A | ||||||
|  |     alignas(64) float As[BLOCK_M]; | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB; | ||||||
|  |       int64_t nb = i % NB; | ||||||
|  |  | ||||||
|  |       // nb0 from top half and nb1 from bottom half | ||||||
|  |       int64_t nb0 = nb, nb1 = nb + NB; | ||||||
|  |       int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       // B shape [K, n_size] in vnni format | ||||||
|  |       int32_t expert_id = expert_ids[mb]; | ||||||
|  |       const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; | ||||||
|  |       const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; | ||||||
|  |       const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; | ||||||
|  |       const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; | ||||||
|  |  | ||||||
|  |       // 1.a load A | ||||||
|  |       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||||
|  |       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||||
|  |  | ||||||
|  |       for (int64_t m = 0; m < m_size; ++m) { | ||||||
|  |         // sorted ids index [M * topk] slots; dividing by topk yields the input row | ||||||
|  |         int32_t index = A_ids[m] / topk; | ||||||
|  |         copy_stub(A + m * K, Aq_tmp + index * K, K); | ||||||
|  |         As[m] = As_tmp[index]; | ||||||
|  |       } | ||||||
|  |  | ||||||
|  |       // fused 1.b: silu_and_mul(A @ B0, A @ B1) | ||||||
|  |       const int64_t offset = offsets[mb]; | ||||||
|  |       tinygemm_kernel( | ||||||
|  |           /* A     */ A, | ||||||
|  |           /* B0    */ B0, | ||||||
|  |           /* B1    */ B1, | ||||||
|  |           /* C     */ ic1 + offset * N + nb * BLOCK_N, | ||||||
|  |           /* As    */ As, | ||||||
|  |           /* Bs0   */ Bs0, | ||||||
|  |           /* Bs1   */ Bs1, | ||||||
|  |           /* M     */ m_size, | ||||||
|  |           /* N     */ n_size, | ||||||
|  |           /* K     */ K, | ||||||
|  |           /* lda   */ K, | ||||||
|  |           /* ldb   */ n_size, | ||||||
|  |           /* ldc   */ N); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 1.5: quantize ic1 to uint8, [M * topk, N] | ||||||
|  |   at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       quantize_row_int8<scalar_t>( | ||||||
|  |           Aq_tmp + m * N, | ||||||
|  |           As_tmp[m], | ||||||
|  |           ic1 + m * N, | ||||||
|  |           N); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||||
|  |   //   w2 : [E, K, N] as [E, OC, IC] | ||||||
|  |   const int64_t OC = K;  // rename K as OC | ||||||
|  |   const int64_t IC = N;  // rename N as IC | ||||||
|  |   const int64_t MB2 = MB; | ||||||
|  |   const int64_t NB2 = div_up(OC, BLOCK_N); | ||||||
|  |   const int64_t stride_e2 = OC * packed_N; | ||||||
|  |   const int64_t stride_oc = packed_N; | ||||||
|  |  | ||||||
|  |   // parallel on [MB2, NB2] | ||||||
|  |   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     // get local pointers | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |     // we won't be using C1 for gemm2 | ||||||
|  |     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB2; | ||||||
|  |       int64_t nb = i % NB2; | ||||||
|  |  | ||||||
|  |       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||||
|  |       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       // A ptr from ic1 of [M * topk, N] in sorted order | ||||||
|  |       // so as to avoid copy A to tmp buffer again | ||||||
|  |       const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; | ||||||
|  |       const float* __restrict__ As = As_tmp + offsets[mb]; | ||||||
|  |       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||||
|  |  | ||||||
|  |       // B shape [IC, n_size] in vnni format | ||||||
|  |       int32_t expert_id = expert_ids[mb]; | ||||||
|  |       const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; | ||||||
|  |       const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; | ||||||
|  |  | ||||||
|  |       // 2.a gemm: C = A @ B | ||||||
|  |       tinygemm_kernel<scalar_t>( | ||||||
|  |           /* A     */ A, | ||||||
|  |           /* B     */ B, | ||||||
|  |           /* C     */ C, | ||||||
|  |           /* As    */ As, | ||||||
|  |           /* Bs    */ Bs, | ||||||
|  |           /* M     */ m_size, | ||||||
|  |           /* N     */ n_size, | ||||||
|  |           /* K     */ IC, | ||||||
|  |           /* lda   */ IC, | ||||||
|  |           /* ldb   */ n_size, | ||||||
|  |           /* ldc   */ BLOCK_N); | ||||||
|  |  | ||||||
|  |       // 2.b copy from C to ic2 in original order | ||||||
|  |       //   and also mul topk_weights in float32 | ||||||
|  |       for (int64_t m = 0; m < m_size; ++m) { | ||||||
|  |         int32_t index = A_ids[m]; | ||||||
|  |         float weight = topk_weights[index]; | ||||||
|  |         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 3: out = intermediate_cache2.sum(dim=1) | ||||||
|  |   //   from [M, topk, K] to [M, K] | ||||||
|  |   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       sum_stub(output + m * K, ic2 + m * topk * K, topk, K); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE)                                                  \ | ||||||
|  |   template void fused_experts_int8_kernel_impl<TYPE> (                                       \ | ||||||
|  |       TYPE* __restrict__ output, TYPE* __restrict__ ic1,                                     \ | ||||||
|  |       TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp,                                   \ | ||||||
|  |       float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp,                               \ | ||||||
|  |       float* __restrict__ As_tmp, const TYPE* __restrict__ input,                            \ | ||||||
|  |       const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2,            \ | ||||||
|  |       const float* __restrict__ w1s, const float* __restrict__ w2s,                          \ | ||||||
|  |       const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids,        \ | ||||||
|  |       const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets,           \ | ||||||
|  |       int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad) | ||||||
|  |  | ||||||
|  | INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | void shared_expert_int8_kernel_impl( | ||||||
|  |     scalar_t* __restrict__ output, | ||||||
|  |     scalar_t* __restrict__ ic1, | ||||||
|  |     float* __restrict__ C_tmp, | ||||||
|  |     uint8_t* __restrict__ Aq_tmp, | ||||||
|  |     float* __restrict__ As_tmp, | ||||||
|  |     const scalar_t* __restrict__ input, | ||||||
|  |     const int8_t* __restrict__ packed_w1, | ||||||
|  |     const int8_t* __restrict__ packed_w2, | ||||||
|  |     const float* __restrict__ w1s, | ||||||
|  |     const float* __restrict__ w2s, | ||||||
|  |     const scalar_t* __restrict__ fused_experts_out, | ||||||
|  |     float routed_scaling_factor, | ||||||
|  |     int64_t M, | ||||||
|  |     int64_t N, | ||||||
|  |     int64_t K) { | ||||||
|  |  | ||||||
|  |   // handle 2 tiles per block | ||||||
|  |   constexpr int64_t BLOCK_M = block_size_m(); | ||||||
|  |   constexpr int64_t BLOCK_N = block_size_n(); | ||||||
|  |  | ||||||
|  |   // stage 0: quantize each input row to uint8 in Aq_tmp, [M, K], with one float scale per row in As_tmp | ||||||
|  |   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       quantize_row_int8<scalar_t>( | ||||||
|  |           Aq_tmp + m * K, | ||||||
|  |           As_tmp[m], | ||||||
|  |           input + m * K, | ||||||
|  |           K); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |    // stage 1: intermediate_cache1 = silu(hidden_states @ w1) | ||||||
|  |   const int64_t MB = div_up(M, BLOCK_M); | ||||||
|  |   const int64_t NB = div_up(N, BLOCK_N); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); | ||||||
|  |  | ||||||
|  |   // K and N are packed for int8 | ||||||
|  |   const int64_t packed_K = get_row_size<int8_t>(K); | ||||||
|  |   const int64_t packed_N = get_row_size<int8_t>(N); | ||||||
|  |   const int64_t stride_n = packed_K; | ||||||
|  |  | ||||||
|  |   // parallelize over only the first N of w1's 2N output columns: each task pairs block nb with block nb + NB so silu_and_mul is fused with the gemm | ||||||
|  |   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB; | ||||||
|  |       int64_t nb = i % NB; | ||||||
|  |  | ||||||
|  |       // nb0 from top half and nb1 from bottom half | ||||||
|  |       int64_t nb0 = nb, nb1 = nb + NB; | ||||||
|  |       int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); | ||||||
|  |       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||||
|  |  | ||||||
|  |       // A shape [m_size, K] | ||||||
|  |       const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; | ||||||
|  |       const float* As = As_tmp + mb * BLOCK_M; | ||||||
|  |  | ||||||
|  |       // B shape [K, n_size] in vnni format | ||||||
|  |       const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; | ||||||
|  |       const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; | ||||||
|  |       const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; | ||||||
|  |       const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; | ||||||
|  |  | ||||||
|  |       // fused 1.b: silu_and_mul(A @ B0, A @ B1) | ||||||
|  |       tinygemm_kernel( | ||||||
|  |           /* A     */ A, | ||||||
|  |           /* B0    */ B0, | ||||||
|  |           /* B1    */ B1, | ||||||
|  |           /* C     */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, | ||||||
|  |           /* As    */ As, | ||||||
|  |           /* Bs0   */ Bs0, | ||||||
|  |           /* Bs1   */ Bs1, | ||||||
|  |           /* M     */ m_size, | ||||||
|  |           /* N     */ n_size, | ||||||
|  |           /* K     */ K, | ||||||
|  |           /* lda   */ K, | ||||||
|  |           /* ldb   */ n_size, | ||||||
|  |           /* ldc   */ N); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 1.5: quantize ic1 to uint8, [M, N]; this overwrites the stage 0 contents of Aq_tmp / As_tmp | ||||||
|  |   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     for (int64_t m = begin; m < end; ++m) { | ||||||
|  |       quantize_row_int8<scalar_t>( | ||||||
|  |           Aq_tmp + m * N, | ||||||
|  |           As_tmp[m], | ||||||
|  |           ic1 + m * N, | ||||||
|  |           N); | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  |  | ||||||
|  |   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||||
|  |   //   w2 : [K, N] as [OC, IC] | ||||||
|  |   const int64_t OC = K;  // rename K as OC | ||||||
|  |   const int64_t IC = N;  // rename N as IC | ||||||
|  |   const int64_t MB2 = MB; | ||||||
|  |   const int64_t NB2 = div_up(OC, BLOCK_N); | ||||||
|  |   const int64_t stride_oc = packed_N; | ||||||
|  |  | ||||||
|  |   // parallel on [MB2, NB2] | ||||||
|  |   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||||
|  |     // get local pointers | ||||||
|  |     int tid = at::get_thread_num(); | ||||||
|  |     // each thread owns a 2 * BLOCK_M * BLOCK_N float slice of C_tmp; gemm2 writes only one [BLOCK_M, BLOCK_N] tile of it (no C1) | ||||||
|  |     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; | ||||||
|  |  | ||||||
|  |     for (int64_t i = begin; i < end; ++i) { | ||||||
|  |       int64_t mb = i / NB2; | ||||||
|  |       int64_t nb = i % NB2; | ||||||
|  |  | ||||||
|  |       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||||
|  |       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||||
|  |  | ||||||
|  |       // A shape [m_size, IC] | ||||||
|  |       const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; | ||||||
|  |       const float* __restrict__ As = As_tmp + mb * BLOCK_M; | ||||||
|  |  | ||||||
|  |       // B shape [IC, n_size] in vnni format | ||||||
|  |       const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; | ||||||
|  |       const float* __restrict__ Bs = w2s + nb * BLOCK_N; | ||||||
|  |  | ||||||
|  |       // 2.a gemm: C = A @ B | ||||||
|  |       tinygemm_kernel<scalar_t>( | ||||||
|  |           /* A     */ A, | ||||||
|  |           /* B     */ B, | ||||||
|  |           /* C     */ C, | ||||||
|  |           /* As    */ As, | ||||||
|  |           /* Bs    */ Bs, | ||||||
|  |           /* M     */ m_size, | ||||||
|  |           /* N     */ n_size, | ||||||
|  |           /* K     */ IC, | ||||||
|  |           /* lda   */ IC, | ||||||
|  |           /* ldb   */ n_size, | ||||||
|  |           /* ldc   */ BLOCK_N); | ||||||
|  |  | ||||||
|  |       // 2.b copy from C to output and add fused_experts_out (presumably scaled by routed_scaling_factor - confirm in add_mul_stub) | ||||||
|  |       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; | ||||||
|  |       const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; | ||||||
|  |       for (int64_t m = 0; m < m_size; ++m) { | ||||||
|  |         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE)                                        \ | ||||||
|  |   template void shared_expert_int8_kernel_impl<TYPE> (                                       \ | ||||||
|  |       TYPE* __restrict__ output, TYPE* __restrict__ ic1,                                     \ | ||||||
|  |       float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp,                               \ | ||||||
|  |       float* __restrict__ As_tmp, const TYPE* __restrict__ input,                            \ | ||||||
|  |       const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2,            \ | ||||||
|  |       const float* __restrict__ w1s, const float* __restrict__ w2s,                          \ | ||||||
|  |       const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor,               \ | ||||||
|  |       int64_t M, int64_t N, int64_t K) | ||||||
|  |  | ||||||
|  | INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); | ||||||
|  | INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); | ||||||
							
								
								
									
										308
									
								
								csrc/cpu/sgl-kernels/vec.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										308
									
								
								csrc/cpu/sgl-kernels/vec.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,308 @@ | |||||||
|  | // Adapted from | ||||||
|  | // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||||
|  |  | ||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | #if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__) | ||||||
|  | #define CPU_CAPABILITY_AVX512 | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #include <ATen/cpu/vec/functional.h> | ||||||
|  | #include <ATen/cpu/vec/vec.h> | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | using namespace at::vec; | ||||||
|  |  | ||||||
|  | template <typename scalar_t, | ||||||
|  |           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> | ||||||
|  | inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) { | ||||||
|  |   return at::vec::convert_from_float<scalar_t>(a, b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  |  | ||||||
|  | // `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics | ||||||
|  | // use the native instruction (_mm512_cvtne2ps_pbh) for float32->bfloat16 conversion; `a` lands in the low half of the result, `b` in the high half | ||||||
|  | template <> | ||||||
|  | inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) { | ||||||
|  |   return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define CVT_BF16_TO_FP32(a) \ | ||||||
|  |     _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) | ||||||
|  |  | ||||||
// fp16 -> fp32: widens 16 packed half-precision values (__m256i) to 16 floats (__m512).
// The previous expansion used _mm512_cvtps_ph, which is the opposite (fp32 -> fp16)
// conversion and does not accept an __m256i argument.
#define CVT_FP16_TO_FP32(a) \
    _mm512_cvtph_ps(a)
|  |  | ||||||
|  | // Fast e4m3 -> bf16 path: this doesn't handle NaN, and exponent-0 inputs are rebiased like normal numbers; only an all-zero byte maps to zero. NOTE(review): 0x80 (-0) is therefore not zeroed - confirm this is intended. | ||||||
|  | inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { | ||||||
|  |   const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); | ||||||
|  |  | ||||||
|  |   const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); | ||||||
|  |   const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); | ||||||
|  |   const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); | ||||||
|  |   const __m512i nonsign = _mm512_or_si512(exp, mant); | ||||||
|  |  | ||||||
|  |   const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); | ||||||
|  |   const __m512i combined = _mm512_or_si512(nonsign, sign); | ||||||
|  |  | ||||||
|  |   const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); | ||||||
|  |   return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { | ||||||
|  |   // This conversion does not implement denormal (subnormal) handling, that is to say, | ||||||
|  |   //   Max subnorm   : S.0000.111 = 0.875 * 2**(-6) | ||||||
|  |   //   Min subnorm   : S.0000.001 = 2**(-9) | ||||||
|  |   // so magnitudes in roughly 0.0019 ~ 0.0137 cannot be converted correctly. | ||||||
|  |   __m512i x = _mm512_cvtepu8_epi16(fp8_vec); | ||||||
|  |   auto mask = _mm512_cmpneq_epi16_mask( | ||||||
|  |       _mm512_and_si512(x, _mm512_set1_epi16(127)), | ||||||
|  |       _mm512_setzero_si512());  // mask: (x & 0x7f) != 0, i.e. lane is not +/-0 | ||||||
|  |   auto mask_nan = _mm512_cmpneq_epi16_mask( | ||||||
|  |       _mm512_and_si512(x, _mm512_set1_epi16(127)), | ||||||
|  |       _mm512_set1_epi16(127));                                                      // mask_nan: (x & 0x7f) != 0x7f, i.e. set for lanes that are NOT NaN | ||||||
|  |   auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4);  // mantissa = (x & 7) << 4 | ||||||
|  |   auto exponent = _mm512_add_epi16( | ||||||
|  |       _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), | ||||||
|  |       _mm512_set1_epi16(120));  // exponent = (((x >> 3) & 15) + 120) | ||||||
|  |   auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); | ||||||
|  |   nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign);  // deal with NaN: lanes with mask_nan clear become 0x7fff | ||||||
|  |   return (__m512bh)(_mm512_or_si512( | ||||||
|  |       nonsign, | ||||||
|  |       _mm512_slli_epi16( | ||||||
|  |           _mm512_and_si512(x, _mm512_set1_epi16(128)), | ||||||
|  |           8)));  // add sign (x & 128) << 8 | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { | ||||||
|  |   __m512i x = _mm512_cvtepu8_epi16(fp8_vec); | ||||||
|  |   __m512i lg2mant = _mm512_mask_mov_epi16( | ||||||
|  |       _mm512_mask_mov_epi16( | ||||||
|  |           _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), | ||||||
|  |       _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), | ||||||
|  |       _mm512_set1_epi16(2)); | ||||||
|  |   return (__m512bh)(_mm512_or_si512( | ||||||
|  |       _mm512_maskz_mov_epi16( | ||||||
|  |           _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), | ||||||
|  |           _mm512_mask_blend_epi16( | ||||||
|  |               _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), | ||||||
|  |               _mm512_or_si512( | ||||||
|  |                   _mm512_and_si512( | ||||||
|  |                       _mm512_sllv_epi16( | ||||||
|  |                           _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), | ||||||
|  |                       _mm512_set1_epi16(0x007f)), | ||||||
|  |                   _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), | ||||||
|  |               _mm512_or_si512( | ||||||
|  |                   _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), | ||||||
|  |                   _mm512_slli_epi16( | ||||||
|  |                       _mm512_add_epi16( | ||||||
|  |                           _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), | ||||||
|  |                       7)))), | ||||||
|  |       _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline __m512bh CVT_FP8_TO_BF16(__m256i a) { | ||||||
|  | #ifdef SGLANG_CPU_FP8_CVT_FTZ | ||||||
|  |   return cvt_e4m3_bf16_intrinsic_no_nan(a); | ||||||
|  | #else | ||||||
|  |   return cvt_e4m3_bf16_intrinsic_with_denorm(a); | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // vector to scalar reduction | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) && 0 | ||||||
|  | inline float vec_reduce_sum(const Vectorized<float>& a) { | ||||||
|  |   return _mm512_reduce_add_ps(__m512(a)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline float vec_reduce_max(const Vectorized<float>& a) { | ||||||
|  |   return _mm512_reduce_max_ps(__m512(a)); | ||||||
|  | } | ||||||
|  | #else | ||||||
|  | inline float vec_reduce_sum(const Vectorized<float>& a) { | ||||||
|  |   return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline float vec_reduce_max(const Vectorized<float>& a) { | ||||||
|  |   return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a); | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
// Dynamic per-row symmetric int8 quantization, following the scheme in
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
//
//   Aq  : [out] K quantized values, stored as uint8 with a +128 offset
//   As  : [out] dequantization scale, max(max|A[k]|, eps) / 127
//   A   : [in]  K input values
//   K   : number of elements in the row
//   eps : lower bound applied to the absolute max, so an all-zero row
//         does not produce a division by zero
template <typename scalar_t>
inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As,
    const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {

  float amax = 0.f; // absolute max
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]);
    amax = std::max(amax, std::abs(val));
  }

  amax = std::max(amax, eps);
  const float scale = amax / 127;
  const float inv_scale = 127 / amax;

  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]) * inv_scale;
    // Round into a signed integer first and only then add the +128 offset.
    // Converting a negative float directly to uint8_t is undefined behavior
    // (and saturates to 0 on some targets instead of wrapping).
    const int32_t q = static_cast<int32_t>(std::round(val)) + 128;
    Aq[k] = static_cast<uint8_t>(q);
  }
  As = scale;
}
|  |  | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | template <> | ||||||
|  | inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As, | ||||||
|  |     const at::BFloat16* __restrict__ A, int64_t K, float eps) { | ||||||
|  |  | ||||||
|  |   const __m512 signBit = _mm512_set1_ps(-0.0f); | ||||||
|  |   const __m512i off = _mm512_set1_epi32(128); | ||||||
|  |  | ||||||
|  |   // precondition: K is a multiple of 32 - both loops below step by 32 elements and have no remainder handling | ||||||
|  |   float amax = 0.f; | ||||||
|  |   __m512 vamax0 = _mm512_set1_ps(0.f); | ||||||
|  |   __m512 vamax1 = _mm512_set1_ps(0.f); | ||||||
|  |   for (int64_t k = 0; k < K; k += 32) { | ||||||
|  |     __m512i va = _mm512_loadu_si512((void*)(A + k)); | ||||||
|  |     __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); | ||||||
|  |     __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); | ||||||
|  |     vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); | ||||||
|  |     vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); | ||||||
|  |   } | ||||||
|  |   amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); | ||||||
|  |   amax = std::max(amax, eps); | ||||||
|  |   const float scale = amax / 127; | ||||||
|  |   const float inv_scale = 127 / amax; | ||||||
|  |   const __m512 vd = _mm512_set1_ps(inv_scale); | ||||||
|  |  | ||||||
|  |   for (int64_t k = 0; k < K; k += 32) { | ||||||
|  |     __m512i va = _mm512_loadu_si512((void*)(A + k)); | ||||||
|  |     __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); | ||||||
|  |     __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); | ||||||
|  |     va0 = _mm512_mul_ps(va0, vd); | ||||||
|  |     va1 = _mm512_mul_ps(va1, vd); | ||||||
|  |     va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); | ||||||
|  |     va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); | ||||||
|  |     __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); | ||||||
|  |     __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); | ||||||
|  |     _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); | ||||||
|  |   } | ||||||
|  |   As = scale; | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | // transpose utils | ||||||
|  | // taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998 | ||||||
|  | #if defined(CPU_CAPABILITY_AVX512) | ||||||
|  | inline void transpose_16x16_32bit(__m512i * v) { | ||||||
|  |   __m512i v1[16]; | ||||||
|  |   v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); | ||||||
|  |   v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); | ||||||
|  |   v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); | ||||||
|  |   v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); | ||||||
|  |   v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); | ||||||
|  |   v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); | ||||||
|  |   v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); | ||||||
|  |   v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); | ||||||
|  |   v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); | ||||||
|  |   v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); | ||||||
|  |   v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); | ||||||
|  |   v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); | ||||||
|  |   v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); | ||||||
|  |   v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); | ||||||
|  |   v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); | ||||||
|  |   v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); | ||||||
|  |  | ||||||
|  |   v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); | ||||||
|  |   v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); | ||||||
|  |   v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); | ||||||
|  |   v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); | ||||||
|  |   v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); | ||||||
|  |   v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); | ||||||
|  |   v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); | ||||||
|  |   v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); | ||||||
|  |   v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); | ||||||
|  |   v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); | ||||||
|  |   v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); | ||||||
|  |   v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); | ||||||
|  |   v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); | ||||||
|  |   v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); | ||||||
|  |   v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); | ||||||
|  |   v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); | ||||||
|  |  | ||||||
|  |   v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); | ||||||
|  |   v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); | ||||||
|  |   v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); | ||||||
|  |   v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); | ||||||
|  |   v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); | ||||||
|  |   v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); | ||||||
|  |   v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); | ||||||
|  |   v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); | ||||||
|  |   v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); | ||||||
|  |   v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); | ||||||
|  |   v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); | ||||||
|  |   v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); | ||||||
|  |   v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); | ||||||
|  |   v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); | ||||||
|  |   v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); | ||||||
|  |   v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); | ||||||
|  |  | ||||||
|  |   v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); | ||||||
|  |   v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); | ||||||
|  |   v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); | ||||||
|  |   v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); | ||||||
|  |   v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); | ||||||
|  |   v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); | ||||||
|  |   v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); | ||||||
|  |   v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); | ||||||
|  |   v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); | ||||||
|  |   v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); | ||||||
|  |   v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); | ||||||
|  |   v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); | ||||||
|  |   v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); | ||||||
|  |   v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); | ||||||
|  |   v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); | ||||||
|  |   v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes] | ||||||
|  | #pragma GCC diagnostic push | ||||||
|  | #pragma GCC diagnostic ignored "-Wignored-attributes" | ||||||
|  |  | ||||||
|  | // transpose two rows of 32 16-bit elements from [2, 32] to [32, 2]; the interleaved result is returned as the pair (d0, d1) | ||||||
|  | inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) { | ||||||
|  |   // inputs: | ||||||
|  |   // r0: {a0, a1, ..., a31} | ||||||
|  |   // r1: {b0, b1, ..., b31} | ||||||
|  |   // returned values (after the final shuffle below): | ||||||
|  |   // d0: {a0,   b0, ..., a15, b15} | ||||||
|  |   // d1: {a16, b16, ..., a31, b31} | ||||||
|  |   __m512i d0 = _mm512_unpacklo_epi16(r0, r1); | ||||||
|  |   __m512i d1 = _mm512_unpackhi_epi16(r0, r1); | ||||||
|  |   r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); | ||||||
|  |   r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); | ||||||
|  |   d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); | ||||||
|  |   d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); | ||||||
|  |   return std::make_tuple(d0, d1); | ||||||
|  | } | ||||||
|  | #pragma GCC diagnostic pop | ||||||
|  |  | ||||||
|  | #endif | ||||||
|  |  | ||||||
// TODO: debug print, remove me later
//
// Prints the first `size` elements of `ptr` to stdout, 16 per row, each
// followed by a single space. A newline precedes every row and one more
// terminates the output.
template<typename scalar_t>
void print_array(scalar_t* ptr, int size) {
  for (int d = 0; d < size; ++d) {
    // '\n' instead of std::endl: start a new row without forcing a flush
    // of stdout for every 16 elements.
    if (d % 16 == 0) { std::cout << '\n'; }
    std::cout << ptr[d] << ' ';
  }
  // Single flush once the whole array has been written.
  std::cout << std::endl;
}
|  |  | ||||||
|  | } // anonymous namespace | ||||||
							
								
								
									
										178
									
								
								csrc/cpu/shm.cpp
									
									
									
									
									
								
							
							
						
						
									
										178
									
								
								csrc/cpu/shm.cpp
									
									
									
									
									
								
							| @ -7,9 +7,10 @@ | |||||||
|  |  | ||||||
| namespace { | namespace { | ||||||
| #define MAX_SHM_RANK_NUM 8 | #define MAX_SHM_RANK_NUM 8 | ||||||
| #define MAX_THREAD_NUM 12 | #define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) | ||||||
| #define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) | static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); | ||||||
| #define MIN_THREAD_PROCESS_SIZE (8 * 1024) | #define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) | ||||||
|  | #define MIN_THREAD_PROCESS_SIZE (256) | ||||||
| #define MAX_P2P_SEND_TENSOR_NUM 8 | #define MAX_P2P_SEND_TENSOR_NUM 8 | ||||||
|  |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| @ -32,10 +33,10 @@ struct KernelVecType<c10::Half> { | |||||||
|   using scalar_vec_t = vec_op::FP16Vec16; |   using scalar_vec_t = vec_op::FP16Vec16; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; |  | ||||||
|  |  | ||||||
| struct ThreadSHMContext { | struct ThreadSHMContext { | ||||||
|   volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; |   volatile char _curr_thread_stamp; | ||||||
|  |   volatile char _ready_thread_stamp; | ||||||
|  |   char _padding1[6]; | ||||||
|   int thread_id; |   int thread_id; | ||||||
|   int thread_num; |   int thread_num; | ||||||
|   int rank; |   int rank; | ||||||
| @ -44,14 +45,19 @@ struct ThreadSHMContext { | |||||||
|   int swizzled_ranks[MAX_SHM_RANK_NUM]; |   int swizzled_ranks[MAX_SHM_RANK_NUM]; | ||||||
|   void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; |   void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; | ||||||
|   ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; |   ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; | ||||||
|  |   size_t _thread_buffer_mask; | ||||||
|  |   char _padding2[56]; | ||||||
|  |  | ||||||
|   ThreadSHMContext(const int thread_id, const int thread_num, const int rank, |   ThreadSHMContext(const int thread_id, const int thread_num, const int rank, | ||||||
|                    const int group_size, void* thread_shm_ptr) |                    const int group_size, void* thread_shm_ptr) | ||||||
|       : thread_id(thread_id), |       : _curr_thread_stamp(1), | ||||||
|  |         _ready_thread_stamp(0), | ||||||
|  |         thread_id(thread_id), | ||||||
|         thread_num(thread_num), |         thread_num(thread_num), | ||||||
|         rank(rank), |         rank(rank), | ||||||
|         group_size(group_size), |         group_size(group_size), | ||||||
|         _spinning_count(0) { |         _spinning_count(0), | ||||||
|  |         _thread_buffer_mask(0) { | ||||||
|     static_assert(sizeof(ThreadSHMContext) % 64 == 0); |     static_assert(sizeof(ThreadSHMContext) % 64 == 0); | ||||||
|     TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); |     TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); | ||||||
|     TORCH_CHECK((size_t)this % 64 == 0); |     TORCH_CHECK((size_t)this % 64 == 0); | ||||||
| @ -60,7 +66,6 @@ struct ThreadSHMContext { | |||||||
|       shm_contexts[i] = nullptr; |       shm_contexts[i] = nullptr; | ||||||
|       thread_shm_ptrs[i] = nullptr; |       thread_shm_ptrs[i] = nullptr; | ||||||
|       swizzled_ranks[i] = (i + rank) % group_size; |       swizzled_ranks[i] = (i + rank) % group_size; | ||||||
|       thread_stats[i] = ThreadSHMStat::DONE; |  | ||||||
|     } |     } | ||||||
|     set_context(rank, this, thread_shm_ptr); |     set_context(rank, this, thread_shm_ptr); | ||||||
|   } |   } | ||||||
| @ -77,59 +82,66 @@ struct ThreadSHMContext { | |||||||
|  |  | ||||||
|   template <typename T> |   template <typename T> | ||||||
|   T* get_thread_shm_ptr(int rank) { |   T* get_thread_shm_ptr(int rank) { | ||||||
|     return reinterpret_cast<T*>(thread_shm_ptrs[rank]); |     return reinterpret_cast<T*>( | ||||||
|  |         reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) + | ||||||
|  |         (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; } | ||||||
|  |  | ||||||
|  |   char get_curr_stamp() const { return _curr_thread_stamp; } | ||||||
|  |  | ||||||
|  |   char get_ready_stamp() const { return _ready_thread_stamp; } | ||||||
|  |  | ||||||
|  |   void next_stamp() { | ||||||
|  |     _mm_mfence(); | ||||||
|  |     _curr_thread_stamp += 1; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void commit_ready_stamp() { | ||||||
|  |     _mm_mfence(); | ||||||
|  |     _ready_thread_stamp = _curr_thread_stamp; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } |   int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } | ||||||
|  |  | ||||||
|   void wait_for_all(ThreadSHMStat prev_stat) { |   template <typename Cond> | ||||||
|     for (int idx = 0; idx < group_size; ++idx) { |   void wait_for_all(Cond&& cond) { | ||||||
|  |     for (int idx = 1; idx < group_size; ++idx) { | ||||||
|       int rank = get_swizzled_rank(idx); |       int rank = get_swizzled_rank(idx); | ||||||
|       while (thread_stats[rank] == prev_stat) { |       wait_for_one(rank, std::forward<Cond>(cond)); | ||||||
|         ++_spinning_count; |  | ||||||
|         _mm_pause(); |  | ||||||
|       } |  | ||||||
|     } |     } | ||||||
|     vec_op::mem_barrier(); |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   void wait_for_one(int rank, ThreadSHMStat prev_stat) { |   template <typename Cond> | ||||||
|     while (thread_stats[rank] == prev_stat) { |   void wait_for_one(int rank, Cond&& cond) { | ||||||
|  |     ThreadSHMContext* rank_ctx = shm_contexts[rank]; | ||||||
|  |     for (;;) { | ||||||
|  |       char local_curr_stamp = get_curr_stamp(); | ||||||
|  |       char local_ready_stamp = get_ready_stamp(); | ||||||
|  |       char rank_curr_stamp = rank_ctx->get_curr_stamp(); | ||||||
|  |       char rank_ready_stamp = rank_ctx->get_ready_stamp(); | ||||||
|  |       if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, | ||||||
|  |                rank_ready_stamp)) { | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|       ++_spinning_count; |       ++_spinning_count; | ||||||
|       _mm_pause(); |       _mm_pause(); | ||||||
|     } |     } | ||||||
|     vec_op::mem_barrier(); |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   void set_thread_stat(ThreadSHMStat stat) { |   static bool check_no_buffer_conflict(char local_curr_stamp, | ||||||
|     for (int idx = 0; idx < group_size; ++idx) { |                                        char local_ready_stamp, | ||||||
|       int rank = get_swizzled_rank(idx); |                                        char rank_curr_stamp, | ||||||
|       shm_contexts[rank]->thread_stats[this->rank] = stat; |                                        char rank_ready_stamp) { | ||||||
|     } |     char temp = rank_curr_stamp + 2; | ||||||
|  |     return local_curr_stamp != temp; | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   void set_thread_stat(int target_rank, ThreadSHMStat stat) { |   static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp, | ||||||
|     for (int idx = 0; idx < group_size; ++idx) { |                                 char rank_curr_stamp, char rank_ready_stamp) { | ||||||
|       int rank = get_swizzled_rank(idx); |     char temp = local_curr_stamp + 1; | ||||||
|       shm_contexts[rank]->thread_stats[target_rank] = stat; |     return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp); | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   // barrier for all ranks in the group, used for all2all ops |  | ||||||
|   // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... |  | ||||||
|   void barrier(ThreadSHMStat next_stat) { |  | ||||||
|     if (next_stat == ThreadSHMStat::THREAD_READY) { |  | ||||||
|       set_thread_stat(ThreadSHMStat::THREAD_READY); |  | ||||||
|       wait_for_all(ThreadSHMStat::DONE); |  | ||||||
|     } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { |  | ||||||
|       set_thread_stat(ThreadSHMStat::SHM_DATA_READY); |  | ||||||
|       wait_for_all(ThreadSHMStat::THREAD_READY); |  | ||||||
|     } else if (next_stat == ThreadSHMStat::DONE) { |  | ||||||
|       set_thread_stat(ThreadSHMStat::DONE); |  | ||||||
|       wait_for_all(ThreadSHMStat::SHM_DATA_READY); |  | ||||||
|     } else { |  | ||||||
|       TORCH_CHECK(false, "Invalid next_stat to barrier."); |  | ||||||
|     } |  | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   std::string to_string() const { |   std::string to_string() const { | ||||||
| @ -164,7 +176,7 @@ class SHMManager { | |||||||
|                       const int group_size) |                       const int group_size) | ||||||
|       : _rank(rank), |       : _rank(rank), | ||||||
|         _group_size(group_size), |         _group_size(group_size), | ||||||
|         _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), |         _thread_num(torch::get_num_threads()), | ||||||
|         _shm_names({""}), |         _shm_names({""}), | ||||||
|         _shared_mem_ptrs({nullptr}), |         _shared_mem_ptrs({nullptr}), | ||||||
|         _shm_ctx(nullptr) { |         _shm_ctx(nullptr) { | ||||||
| @ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { | |||||||
|       (total_units_num + thread_num - 1) / thread_num; |       (total_units_num + thread_num - 1) / thread_num; | ||||||
|   int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); |   int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); | ||||||
|   int64_t max_per_thread_iteration_elem_num = |   int64_t max_per_thread_iteration_elem_num = | ||||||
|       PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); |       (PER_THREAD_SHM_BUFFER_BYTES >> 1) / | ||||||
|  |       sizeof(scalar_t);  // Note: double buffer | ||||||
|   int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; |   int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; | ||||||
|  |  | ||||||
| #pragma omp parallel for schedule(static, 1) | #pragma omp parallel for schedule(static, 1) | ||||||
| @ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { | |||||||
|     int64_t curr_elem_num = |     int64_t curr_elem_num = | ||||||
|         std::min(max_per_thread_iteration_elem_num, end - offset); |         std::min(max_per_thread_iteration_elem_num, end - offset); | ||||||
|     ThreadSHMContext* thread_ctx = ctx + i; |     ThreadSHMContext* thread_ctx = ctx + i; | ||||||
|  |     bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num); | ||||||
|  |  | ||||||
|     while (curr_elem_num > 0) { |     while (curr_elem_num > 0) { | ||||||
|       inner_func(thread_ctx, offset, curr_elem_num); |       inner_func(thread_ctx, offset, curr_elem_num, fast_mode); | ||||||
|  |  | ||||||
|  |       thread_ctx->next_stamp(); | ||||||
|  |       thread_ctx->next_buffer(); | ||||||
|       offset += max_per_thread_iteration_elem_num; |       offset += max_per_thread_iteration_elem_num; | ||||||
|       curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); |       curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); | ||||||
|     } |     } | ||||||
| @ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, | |||||||
|   shm_cc_ops::shm_cc_loop<scalar_t>( |   shm_cc_ops::shm_cc_loop<scalar_t>( | ||||||
|       ctx, elem_num, |       ctx, elem_num, | ||||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, |       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||||
|           int64_t data_elem_num) { |           int64_t data_elem_num, bool fast_mode) { | ||||||
|         int rank = thread_ctx->rank; |         int rank = thread_ctx->rank; | ||||||
|         scalar_t* thread_shm_ptr = |         scalar_t* thread_shm_ptr = | ||||||
|             thread_ctx->get_thread_shm_ptr<scalar_t>(rank); |             thread_ctx->get_thread_shm_ptr<scalar_t>(rank); | ||||||
| @ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, | |||||||
|               thread_ctx->get_swizzled_rank(idx + 1)); |               thread_ctx->get_swizzled_rank(idx + 1)); | ||||||
|         }); |         }); | ||||||
|  |  | ||||||
|         thread_ctx->barrier(ThreadSHMStat::THREAD_READY); |         if (!fast_mode) { | ||||||
|  |           thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, |         shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, | ||||||
|                                   thread_data_elem_num); |                                   thread_data_elem_num); | ||||||
|  |         thread_ctx->commit_ready_stamp(); | ||||||
|         thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); |  | ||||||
|  |  | ||||||
|         int64_t aligned_data_elem_num = |         int64_t aligned_data_elem_num = | ||||||
|             (data_elem_num / vec_elem_num) * vec_elem_num; |             (data_elem_num / vec_elem_num) * vec_elem_num; | ||||||
|         int64_t i = 0; |         int64_t i = 0; | ||||||
|  |         thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready); | ||||||
| #pragma GCC unroll 4 | #pragma GCC unroll 4 | ||||||
|         for (; i < aligned_data_elem_num; i += vec_elem_num) { |         for (; i < aligned_data_elem_num; i += vec_elem_num) { | ||||||
|           vec_t local_data(thread_data_ptr + i);  // load from cache |           vec_t local_data(thread_data_ptr + i);  // load from cache | ||||||
| @ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, | |||||||
|           reduced_data.save(thread_data_ptr + i, |           reduced_data.save(thread_data_ptr + i, | ||||||
|                             data_elem_num - aligned_data_elem_num); |                             data_elem_num - aligned_data_elem_num); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         thread_ctx->barrier(ThreadSHMStat::DONE); |  | ||||||
|       }); |       }); | ||||||
|  |  | ||||||
|   return; |   return; | ||||||
| @ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, | |||||||
|   shm_cc_ops::shm_cc_loop<scalar_t>( |   shm_cc_ops::shm_cc_loop<scalar_t>( | ||||||
|       ctx, elem_num, |       ctx, elem_num, | ||||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, |       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||||
|           int64_t data_elem_num) { |           int64_t data_elem_num, bool fast_mode) { | ||||||
|         int rank = thread_ctx->rank; |         int rank = thread_ctx->rank; | ||||||
|         scalar_t* thread_shm_ptr = |         scalar_t* thread_shm_ptr = | ||||||
|             thread_ctx->get_thread_shm_ptr<scalar_t>(rank); |             thread_ctx->get_thread_shm_ptr<scalar_t>(rank); | ||||||
|  |  | ||||||
|         thread_ctx->barrier(ThreadSHMStat::THREAD_READY); |         if (!fast_mode) { | ||||||
|  |           thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); | ||||||
|         shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, |         } | ||||||
|                                   data_elem_num * sizeof(scalar_t)); |  | ||||||
|  |  | ||||||
|         thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); |  | ||||||
|  |  | ||||||
|  |         shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset, | ||||||
|  |                            data_elem_num * sizeof(scalar_t)); | ||||||
|  |         thread_ctx->commit_ready_stamp(); | ||||||
|         if (rank == dst) { |         if (rank == dst) { | ||||||
|           shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, |           shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, | ||||||
|                              data_elem_num * sizeof(scalar_t)); |                              data_elem_num * sizeof(scalar_t)); | ||||||
| @ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, | |||||||
|             scalar_t* src_ptr = |             scalar_t* src_ptr = | ||||||
|                 thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank);  // shm |                 thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank);  // shm | ||||||
|             scalar_t* dst_ptr = outputs[src_rank] + data_offset; |             scalar_t* dst_ptr = outputs[src_rank] + data_offset; | ||||||
|             shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, |             thread_ctx->wait_for_one(src_rank, | ||||||
|                                         data_elem_num * sizeof(scalar_t)); |                                      ThreadSHMContext::check_stamp_ready); | ||||||
|  |             shm_cc_ops::memcpy(dst_ptr, src_ptr, | ||||||
|  |                                data_elem_num * sizeof(scalar_t)); | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         thread_ctx->barrier(ThreadSHMStat::DONE); |  | ||||||
|       }); |       }); | ||||||
|  |  | ||||||
|   return; |   return; | ||||||
| @ -599,7 +614,7 @@ struct TensorListMeta { | |||||||
|   int8_t _padding[40]; |   int8_t _padding[40]; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| void shm_send_tensor_list_impl(ThreadSHMContext* ctx, | void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, | ||||||
|                                const std::vector<torch::Tensor>& tensor_list) { |                                const std::vector<torch::Tensor>& tensor_list) { | ||||||
|   CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) |   CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) | ||||||
|   std::vector<torch::Tensor> tensor_list_with_metadata; |   std::vector<torch::Tensor> tensor_list_with_metadata; | ||||||
| @ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, | |||||||
|   shm_cc_ops::shm_cc_loop<int8_t>( |   shm_cc_ops::shm_cc_loop<int8_t>( | ||||||
|       ctx, metadata->total_bytes, |       ctx, metadata->total_bytes, | ||||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, |       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||||
|           int64_t data_elem_num) { |           int64_t data_elem_num, bool fast_mode) { | ||||||
|         int rank = thread_ctx->rank; |         int rank = thread_ctx->rank; | ||||||
|         // Wait until the receiver set the stat to DONE |  | ||||||
|         thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); |  | ||||||
|  |  | ||||||
|         int64_t curr_shm_offset = 0; |         int64_t curr_shm_offset = 0; | ||||||
|  |         thread_ctx->wait_for_one(dst, | ||||||
|  |                                  ThreadSHMContext::check_no_buffer_conflict); | ||||||
|         while (curr_shm_offset < data_elem_num) { |         while (curr_shm_offset < data_elem_num) { | ||||||
|           MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); |           MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); | ||||||
|           frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); |           frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); | ||||||
| @ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, | |||||||
|               frag.ptr, frag.size); |               frag.ptr, frag.size); | ||||||
|           curr_shm_offset += frag.size; |           curr_shm_offset += frag.size; | ||||||
|         } |         } | ||||||
|  |         thread_ctx->commit_ready_stamp(); | ||||||
|         thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); |  | ||||||
|       }); |       }); | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -646,8 +659,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx, | |||||||
|   torch::Tensor metadata_tensor = |   torch::Tensor metadata_tensor = | ||||||
|       torch::empty({sizeof(TensorListMeta)}, options); |       torch::empty({sizeof(TensorListMeta)}, options); | ||||||
|  |  | ||||||
|   // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY |   ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); | ||||||
|   ctx->wait_for_one(src, ThreadSHMStat::DONE); |  | ||||||
|   shm_cc_ops::memcpy(metadata_tensor.data_ptr(), |   shm_cc_ops::memcpy(metadata_tensor.data_ptr(), | ||||||
|                      ctx->get_thread_shm_ptr<void>(src), |                      ctx->get_thread_shm_ptr<void>(src), | ||||||
|                      sizeof(TensorListMeta)); |                      sizeof(TensorListMeta)); | ||||||
| @ -664,9 +676,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx, | |||||||
|   shm_cc_ops::shm_cc_loop<int8_t>( |   shm_cc_ops::shm_cc_loop<int8_t>( | ||||||
|       ctx, metadata.total_bytes, |       ctx, metadata.total_bytes, | ||||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, |       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||||
|           int64_t data_elem_num) { |           int64_t data_elem_num, bool fast_mode) { | ||||||
|         // Wait until the sender set the stat to SHM_DATA_READY |         ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); | ||||||
|         thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); |  | ||||||
|         int64_t curr_shm_offset = 0; |         int64_t curr_shm_offset = 0; | ||||||
|         while (curr_shm_offset < data_elem_num) { |         while (curr_shm_offset < data_elem_num) { | ||||||
|           MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); |           MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); | ||||||
| @ -677,8 +688,6 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx, | |||||||
|               frag.size); |               frag.size); | ||||||
|           curr_shm_offset += frag.size; |           curr_shm_offset += frag.size; | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); |  | ||||||
|       }); |       }); | ||||||
|  |  | ||||||
|   std::vector<torch::Tensor> tensor_list; |   std::vector<torch::Tensor> tensor_list; | ||||||
| @ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle, | |||||||
|                           int64_t dst) { |                           int64_t dst) { | ||||||
|   CPU_KERNEL_GUARD_IN(shm_send_tensor_list) |   CPU_KERNEL_GUARD_IN(shm_send_tensor_list) | ||||||
|   shm_send_tensor_list_impl( |   shm_send_tensor_list_impl( | ||||||
|       SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); |       SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst, | ||||||
|  |       tensor_list); | ||||||
|   CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) |   CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) { | |||||||
|   TORCH_CHECK(shm_manager); |   TORCH_CHECK(shm_manager); | ||||||
|   shm_manager->join(name); |   shm_manager->join(name); | ||||||
|   return shm_manager->get_shm_ctx()->to_string(); |   return shm_manager->get_shm_ctx()->to_string(); | ||||||
| } | } | ||||||
|  | |||||||
| @ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle, | |||||||
|  |  | ||||||
| std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src); | std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src); | ||||||
|  |  | ||||||
|  | at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, | ||||||
|  |                                 const std::optional<at::Tensor>& bias, | ||||||
|  |                                 bool is_vnni); | ||||||
|  |  | ||||||
|  | at::Tensor convert_weight_packed(at::Tensor& weight); | ||||||
|  |  | ||||||
|  | at::Tensor fused_experts_cpu( | ||||||
|  |     at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2, | ||||||
|  |     at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, | ||||||
|  |     bool use_int8_w8a8, bool use_fp8_w8a16, | ||||||
|  |     const std::optional<at::Tensor>& w1_scale, | ||||||
|  |     const std::optional<at::Tensor>& w2_scale, | ||||||
|  |     const std::optional<std::vector<int64_t>> block_size, | ||||||
|  |     const std::optional<at::Tensor>& a1_scale, | ||||||
|  |     const std::optional<at::Tensor>& a2_scale, bool is_vnni); | ||||||
|  |  | ||||||
|  | at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, | ||||||
|  |                                      at::Tensor& scales2, | ||||||
|  |                                      const std::optional<at::Tensor>& bias, | ||||||
|  |                                      at::ScalarType out_dtype, bool is_vnni); | ||||||
|  |  | ||||||
| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||||
|   // vLLM custom ops |   // vLLM custom ops | ||||||
|  |  | ||||||
| @ -130,17 +151,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |||||||
|   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); |   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); | ||||||
|  |  | ||||||
|   // Quantization |   // Quantization | ||||||
| #ifdef __AVX512F__ | #if defined(__AVX512F__) || defined(__aarch64__) | ||||||
|  |   at::Tag stride_tag = at::Tag::needs_fixed_stride_order; | ||||||
|  |  | ||||||
|   // Compute int8 quantized tensor for given scaling factor. |   // Compute int8 quantized tensor for given scaling factor. | ||||||
|   ops.def( |   ops.def( | ||||||
|       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," |       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," | ||||||
|       "Tensor? azp) -> ()"); |       "Tensor? azp) -> ()", | ||||||
|  |       {stride_tag}); | ||||||
|   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); |   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); | ||||||
|  |  | ||||||
|   // Compute int8 quantized tensor and scaling factor |   // Compute int8 quantized tensor and scaling factor | ||||||
|   ops.def( |   ops.def( | ||||||
|       "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " |       "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " | ||||||
|       "Tensor!? azp) -> ()"); |       "Tensor!? azp) -> ()", | ||||||
|  |       {stride_tag}); | ||||||
|   ops.impl("dynamic_scaled_int8_quant", torch::kCPU, |   ops.impl("dynamic_scaled_int8_quant", torch::kCPU, | ||||||
|            &dynamic_scaled_int8_quant); |            &dynamic_scaled_int8_quant); | ||||||
|   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column |   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column | ||||||
| @ -148,7 +173,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |||||||
|   ops.def( |   ops.def( | ||||||
|       "cutlass_scaled_mm(Tensor! out, Tensor a," |       "cutlass_scaled_mm(Tensor! out, Tensor a," | ||||||
|       "                  Tensor b, Tensor a_scales," |       "                  Tensor b, Tensor a_scales," | ||||||
|       "                  Tensor b_scales, Tensor? bias) -> ()"); |       "                  Tensor b_scales, Tensor? bias) -> ()", | ||||||
|  |       {stride_tag}); | ||||||
|   ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); |   ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); | ||||||
|   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column |   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column | ||||||
|   // quantization. |   // quantization. | ||||||
| @ -156,7 +182,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |||||||
|       "cutlass_scaled_mm_azp(Tensor! out, Tensor a," |       "cutlass_scaled_mm_azp(Tensor! out, Tensor a," | ||||||
|       "                  Tensor b, Tensor a_scales," |       "                  Tensor b, Tensor a_scales," | ||||||
|       "                  Tensor b_scales, Tensor azp_adj," |       "                  Tensor b_scales, Tensor azp_adj," | ||||||
|       "                  Tensor? azp, Tensor? bias) -> ()"); |       "                  Tensor? azp, Tensor? bias) -> ()", | ||||||
|  |       {stride_tag}); | ||||||
|   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); |   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); | ||||||
| #elif defined(__powerpc64__) | #elif defined(__powerpc64__) | ||||||
|   // Compute int8 quantized tensor for given scaling factor. |   // Compute int8 quantized tensor for given scaling factor. | ||||||
| @ -209,6 +236,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |||||||
|   ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", |   ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", | ||||||
|           &shm_recv_tensor_list); |           &shm_recv_tensor_list); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  |   // sgl-kernels | ||||||
|  | #if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__) | ||||||
|  |   ops.def( | ||||||
|  |       "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? " | ||||||
|  |       "bias, bool is_vnni) -> Tensor"); | ||||||
|  |   ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear); | ||||||
|  |   ops.def("convert_weight_packed(Tensor! weight) -> Tensor"); | ||||||
|  |   ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); | ||||||
|  |   ops.def( | ||||||
|  |       "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor " | ||||||
|  |       "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool " | ||||||
|  |       "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? " | ||||||
|  |       "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> " | ||||||
|  |       "Tensor"); | ||||||
|  |   ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); | ||||||
|  |   ops.def( | ||||||
|  |       "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, " | ||||||
|  |       "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); | ||||||
|  |   ops.impl("int8_scaled_mm_with_quant", torch::kCPU, | ||||||
|  |            &int8_scaled_mm_with_quant); | ||||||
|  | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
| TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { | TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { | ||||||
|  | |||||||
| @ -54,8 +54,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { | |||||||
|     *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); |     *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); | ||||||
|     int page_num = numa_migrate_pages(pid, src_mask, mask); |     int page_num = numa_migrate_pages(pid, src_mask, mask); | ||||||
|     if (page_num == -1) { |     if (page_num == -1) { | ||||||
|       TORCH_CHECK(false, |       TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno)); | ||||||
|                   "numa_migrate_pages failed. errno: " + std::to_string(errno)); |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // restrict memory allocation node. |     // restrict memory allocation node. | ||||||
| @ -105,4 +104,4 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { | |||||||
|  |  | ||||||
|   return ss.str(); |   return ss.str(); | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
|  | |||||||
| @ -4,10 +4,10 @@ | |||||||
|   #include <hip/hip_runtime.h> |   #include <hip/hip_runtime.h> | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #ifndef USE_ROCM | #if defined(USE_ROCM) && defined(__GFX9__) | ||||||
|   #define WARP_SIZE 32 |   #define WARP_SIZE 64 | ||||||
| #else | #else | ||||||
|   #define WARP_SIZE warpSize |   #define WARP_SIZE 32 | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| #ifndef USE_ROCM | #ifndef USE_ROCM | ||||||
|  | |||||||
							
								
								
									
										114
									
								
								csrc/custom_quickreduce.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								csrc/custom_quickreduce.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,114 @@ | |||||||
|  | #include <ATen/cuda/Exceptions.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include <c10/cuda/CUDAStream.h> | ||||||
|  | #include <torch/all.h> | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |  | ||||||
|  |   #include "quickreduce/quick_reduce.h" | ||||||
|  |  | ||||||
|  | quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, | ||||||
|  |                                    std::optional<int64_t> qr_max_size) { | ||||||
|  |   if (world_size > 8) | ||||||
|  |     throw std::invalid_argument("world size > 8 is not supported"); | ||||||
|  |   if (world_size == 6) | ||||||
|  |     throw std::invalid_argument("world size == 6 is not supported"); | ||||||
|  |   if (world_size % 2 != 0) | ||||||
|  |     throw std::invalid_argument("Odd num gpus is not supported for now"); | ||||||
|  |   if (rank < 0 || rank >= world_size) | ||||||
|  |     throw std::invalid_argument("invalid rank passed in"); | ||||||
|  |   quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); | ||||||
|  |   fptr->init(world_size, rank, qr_max_size); | ||||||
|  |   return (quickreduce::fptr_t)fptr; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void qr_destroy(quickreduce::fptr_t _fa) { | ||||||
|  |   if (_fa) { | ||||||
|  |     auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||||
|  |     fa->destroy(); | ||||||
|  |     delete fa; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { | ||||||
|  |   auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||||
|  |   hipIpcMemHandle_t handle = fa->get_handle(); | ||||||
|  |   auto options = | ||||||
|  |       torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); | ||||||
|  |   auto data_handle = | ||||||
|  |       torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options); | ||||||
|  |   std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); | ||||||
|  |   return data_handle; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void qr_open_handles(quickreduce::fptr_t _fa, | ||||||
|  |                      const std::vector<torch::Tensor>& handles) { | ||||||
|  |   auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||||
|  |   std::vector<hipIpcMemHandle_t> ipc_handles; | ||||||
|  |   ipc_handles.reserve(handles.size()); | ||||||
|  |   for (auto& handle : handles) { | ||||||
|  |     // Ensure the tensor is on the same device as the current device. | ||||||
|  |     hipIpcMemHandle_t ipc_handle; | ||||||
|  |     std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); | ||||||
|  |     ipc_handles.push_back(ipc_handle); | ||||||
|  |   } | ||||||
|  |   fa->open_ipc_handles(ipc_handles); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, | ||||||
|  |                    torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { | ||||||
|  |   auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); | ||||||
|  |   auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); | ||||||
|  |  | ||||||
|  |   TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); | ||||||
|  |   TORCH_CHECK_EQ(inp.numel(), out.numel()); | ||||||
|  |   TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); | ||||||
|  |   if (out.scalar_type() == at::ScalarType::Half) { | ||||||
|  |     fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()), | ||||||
|  |                                reinterpret_cast<half*>(out.data_ptr()), | ||||||
|  |                                out.numel(), quant_level, stream); | ||||||
|  |   } else if (out.scalar_type() == at::ScalarType::BFloat16) { | ||||||
|  |     if (cast_bf2half) { | ||||||
|  |       fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()), | ||||||
|  |                                 reinterpret_cast<half*>(out.data_ptr()), | ||||||
|  |                                 out.numel(), quant_level, stream); | ||||||
|  |     } else { | ||||||
|  |       fa->allreduce<quickreduce::nv_bfloat16, false>( | ||||||
|  |           reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()), | ||||||
|  |           reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()), | ||||||
|  |           out.numel(), quant_level, stream); | ||||||
|  |     } | ||||||
|  |   } else { | ||||||
|  |     throw std::runtime_error( | ||||||
|  |         "quick allreduce only supports float16 and bfloat16"); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int64_t qr_max_size() { | ||||||
|  |   // The default is 2GB (2,147,483,648 bytes) | ||||||
|  |   return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |   #define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half)       \ | ||||||
|  |     template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \ | ||||||
|  |                                                   cast_bf2half>;  \ | ||||||
|  |     template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \ | ||||||
|  |                                                   cast_bf2half>;  \ | ||||||
|  |     template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>; | ||||||
|  |  | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) | ||||||
|  |  | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) | ||||||
|  | INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) | ||||||
|  |  | ||||||
|  | #endif  // USE_ROCM | ||||||
| @ -153,7 +153,7 @@ struct ScaledEpilogueBias | |||||||
|       cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>; |       cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>; | ||||||
|  |  | ||||||
|   using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< |   using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
| @ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp | |||||||
|                                               EVTComputeAzp>; |                                               EVTComputeAzp>; | ||||||
|  |  | ||||||
|   using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< |   using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
| @ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken | |||||||
|                                               EVTComputeAcc>; |                                               EVTComputeAcc>; | ||||||
|  |  | ||||||
|   using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< |   using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
|  | |||||||
| @ -195,7 +195,7 @@ struct ScaledEpilogueBias | |||||||
|       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>; |       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>; | ||||||
|  |  | ||||||
|   using Compute1 = cutlass::epilogue::fusion::Sm90Compute< |   using Compute1 = cutlass::epilogue::fusion::Sm90Compute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
| @ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias | |||||||
|       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>; |       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>; | ||||||
|  |  | ||||||
|   using Compute1 = cutlass::epilogue::fusion::Sm90Compute< |   using Compute1 = cutlass::epilogue::fusion::Sm90Compute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
| @ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp | |||||||
|       cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>; |       cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>; | ||||||
|  |  | ||||||
|   using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< |   using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
| @ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken | |||||||
|       cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>; |       cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>; | ||||||
|  |  | ||||||
|   using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< |   using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< | ||||||
|       cutlass::multiply_add, ElementD, float, |       cutlass::homogeneous_multiply_add, ElementD, float, | ||||||
|       cutlass::FloatRoundStyle::round_to_nearest>; |       cutlass::FloatRoundStyle::round_to_nearest>; | ||||||
|  |  | ||||||
|  public: |  public: | ||||||
|  | |||||||
| @ -45,7 +45,6 @@ | |||||||
| #include "cute/algorithm/functional.hpp" | #include "cute/algorithm/functional.hpp" | ||||||
| #include "cute/atom/mma_atom.hpp" | #include "cute/atom/mma_atom.hpp" | ||||||
| #include "cute/algorithm/gemm.hpp" | #include "cute/algorithm/gemm.hpp" | ||||||
| #include "cute/tensor_predicate.hpp" |  | ||||||
| #include "cute/numeric/arithmetic_tuple.hpp" | #include "cute/numeric/arithmetic_tuple.hpp" | ||||||
|  |  | ||||||
| #include "cutlass_extensions/gemm/dispatch_policy.hpp" | #include "cutlass_extensions/gemm/dispatch_policy.hpp" | ||||||
|  | |||||||
| @ -1,660 +0,0 @@ | |||||||
| // clang-format off |  | ||||||
| // adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu  |  | ||||||
| // and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu |  | ||||||
| #include <torch/all.h> |  | ||||||
| #include <ATen/cuda/CUDAContext.h> |  | ||||||
| #include <c10/cuda/CUDAGuard.h> |  | ||||||
|  |  | ||||||
| #include "causal_conv1d.h" |  | ||||||
| #include <c10/util/BFloat16.h> |  | ||||||
| #include <c10/util/Half.h> |  | ||||||
| #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK |  | ||||||
|  |  | ||||||
| #include <cub/block/block_load.cuh> |  | ||||||
| #include <cub/block/block_store.cuh> |  | ||||||
|  |  | ||||||
| #ifdef USE_ROCM |  | ||||||
|     namespace cub = hipcub; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| #include "static_switch.h" |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") |  | ||||||
|  |  | ||||||
| #define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)              \ |  | ||||||
|     if (ITYPE == at::ScalarType::Half) {                                            \ |  | ||||||
|         using input_t = at::Half;                                                   \ |  | ||||||
|         using weight_t = at::Half;                                                  \ |  | ||||||
|         __VA_ARGS__();                                                              \ |  | ||||||
|     } else if (ITYPE == at::ScalarType::BFloat16) {                                 \ |  | ||||||
|         using input_t = at::BFloat16;                                               \ |  | ||||||
|         using weight_t = at::BFloat16;                                              \ |  | ||||||
|         __VA_ARGS__();                                                              \ |  | ||||||
|     } else if (ITYPE == at::ScalarType::Float)  {                                   \ |  | ||||||
|         using input_t = float;                                                      \ |  | ||||||
|         using weight_t = float;                                                     \ |  | ||||||
|         __VA_ARGS__();                                                              \ |  | ||||||
|     } else {                                                                        \ |  | ||||||
|         AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| template<typename input_t, typename weight_t> |  | ||||||
| void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream); |  | ||||||
|  |  | ||||||
| template<typename input_t, typename weight_t> |  | ||||||
| void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream); |  | ||||||
|  |  | ||||||
| void set_conv_params_fwd(ConvParamsBase &params, |  | ||||||
|                          // sizes |  | ||||||
|                          const size_t batch, |  | ||||||
|                          const size_t dim, |  | ||||||
|                          const size_t seqlen, |  | ||||||
|                          const size_t width, |  | ||||||
|                          // device pointers |  | ||||||
|                          const at::Tensor x, |  | ||||||
|                          const at::Tensor weight, |  | ||||||
|                          const at::Tensor out, |  | ||||||
|                          const std::optional<at::Tensor>& bias, |  | ||||||
|                          bool silu_activation, |  | ||||||
|                          int64_t pad_slot_id, |  | ||||||
|                          const std::optional<at::Tensor>& query_start_loc = std::nullopt, |  | ||||||
|                          const std::optional<at::Tensor>& cache_indices = std::nullopt, |  | ||||||
|                          const std::optional<at::Tensor>& has_initial_state = std::nullopt) { |  | ||||||
|  |  | ||||||
|     // Reset the parameters |  | ||||||
|     memset(&params, 0, sizeof(params)); |  | ||||||
|  |  | ||||||
|     params.batch = batch; |  | ||||||
|     params.dim = dim; |  | ||||||
|     params.seqlen = seqlen; |  | ||||||
|     params.width = width; |  | ||||||
|     params.pad_slot_id = pad_slot_id; |  | ||||||
|  |  | ||||||
|     params.silu_activation = silu_activation; |  | ||||||
|  |  | ||||||
|     // Set the pointers and strides. |  | ||||||
|     params.x_ptr = x.data_ptr(); |  | ||||||
|     params.weight_ptr = weight.data_ptr(); |  | ||||||
|     params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; |  | ||||||
|     params.out_ptr = out.data_ptr(); |  | ||||||
|     // All stride are in elements, not bytes. |  | ||||||
|     params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; |  | ||||||
|     params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; |  | ||||||
|     params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; |  | ||||||
|     const bool varlen = params.query_start_loc_ptr != nullptr; |  | ||||||
|     params.x_batch_stride = x.stride(varlen ? 1 : 0); |  | ||||||
|     params.x_c_stride = x.stride(varlen ? 0 : 1); |  | ||||||
|     params.x_l_stride = x.stride(varlen ? 1 : -1); |  | ||||||
|     params.weight_c_stride = weight.stride(0); |  | ||||||
|     params.weight_width_stride = weight.stride(1); |  | ||||||
|     params.out_batch_stride = out.stride(varlen ? 1 : 0); |  | ||||||
|     params.out_c_stride = out.stride(varlen ? 0 : 1); |  | ||||||
|     params.out_l_stride = out.stride(varlen ? 1 : -1); |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, |  | ||||||
|                   const std::optional<at::Tensor> &bias_, |  | ||||||
|                   const std::optional<at::Tensor> &conv_states, |  | ||||||
|                   const std::optional<at::Tensor> &query_start_loc, |  | ||||||
|                   const std::optional<at::Tensor> &cache_indices, |  | ||||||
|                   const std::optional<at::Tensor> &has_initial_state, |  | ||||||
|                   bool silu_activation, |  | ||||||
|                  // used to identify padding entries if cache_indices provided |  | ||||||
|                  // in case of padding, the kernel will return early |  | ||||||
|                   int64_t pad_slot_id) { |  | ||||||
|     auto input_type = x.scalar_type(); |  | ||||||
|     auto weight_type = weight.scalar_type(); |  | ||||||
|     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); |  | ||||||
|     TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); |  | ||||||
|  |  | ||||||
|     TORCH_CHECK(x.is_cuda()); |  | ||||||
|     TORCH_CHECK(weight.is_cuda()); |  | ||||||
|      |  | ||||||
|     const bool varlen = query_start_loc.has_value() ? true : false; |  | ||||||
|     const auto sizes = x.sizes(); |  | ||||||
|     const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; |  | ||||||
|     const int dim = varlen ? sizes[0] : sizes[1]; |  | ||||||
|     const int seqlen = varlen ? sizes[1] : sizes[2]; |  | ||||||
|     const int width = weight.size(-1); |  | ||||||
|     if (varlen){ |  | ||||||
|         CHECK_SHAPE(x, dim, seqlen); |  | ||||||
|     } |  | ||||||
|     else { |  | ||||||
|         CHECK_SHAPE(x, batch_size, dim, seqlen); |  | ||||||
|     } |  | ||||||
|     CHECK_SHAPE(weight, dim, width); |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     if (bias_.has_value()) { |  | ||||||
|         auto bias = bias_.value(); |  | ||||||
|         TORCH_CHECK(bias.scalar_type() == weight_type); |  | ||||||
|         TORCH_CHECK(bias.is_cuda()); |  | ||||||
|         TORCH_CHECK(bias.stride(-1) == 1); |  | ||||||
|         CHECK_SHAPE(bias, dim); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     if (has_initial_state.has_value()) { |  | ||||||
|         auto has_initial_state_ = has_initial_state.value(); |  | ||||||
|         TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); |  | ||||||
|         TORCH_CHECK(has_initial_state_.is_cuda()); |  | ||||||
|         CHECK_SHAPE(has_initial_state_, batch_size); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     if (query_start_loc.has_value()) { |  | ||||||
|         auto query_start_loc_ = query_start_loc.value(); |  | ||||||
|         TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); |  | ||||||
|         TORCH_CHECK(query_start_loc_.is_cuda()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     if (cache_indices.has_value()) { |  | ||||||
|         auto cache_indices_ = cache_indices.value(); |  | ||||||
|         TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); |  | ||||||
|         TORCH_CHECK(cache_indices_.is_cuda()); |  | ||||||
|         CHECK_SHAPE(cache_indices_, batch_size); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     at::Tensor out = x; |  | ||||||
|  |  | ||||||
|     ConvParamsBase params; |  | ||||||
|     set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, |  | ||||||
|                         bias_, |  | ||||||
|                         silu_activation,  |  | ||||||
|                         pad_slot_id, |  | ||||||
|                         query_start_loc, |  | ||||||
|                         cache_indices, |  | ||||||
|                         has_initial_state |  | ||||||
|                         ); |  | ||||||
|  |  | ||||||
|     if (conv_states.has_value()) { |  | ||||||
|         auto conv_states_ = conv_states.value(); |  | ||||||
|         TORCH_CHECK(conv_states_.scalar_type() == input_type); |  | ||||||
|         TORCH_CHECK(conv_states_.is_cuda()); |  | ||||||
|         params.conv_states_ptr = conv_states_.data_ptr(); |  | ||||||
|         params.conv_states_batch_stride = conv_states_.stride(0); |  | ||||||
|         params.conv_states_c_stride = conv_states_.stride(1); |  | ||||||
|         params.conv_states_l_stride = conv_states_.stride(2); |  | ||||||
|     } else { |  | ||||||
|         params.conv_states_ptr = nullptr; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // Otherwise the kernel will be launched from cuda:0 device |  | ||||||
|     // Cast to char to avoid compiler warning about narrowing |  | ||||||
|     at::cuda::CUDAGuard device_guard{(char)x.get_device()}; |  | ||||||
|     auto stream = at::cuda::getCurrentCUDAStream().stream(); |  | ||||||
|     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { |  | ||||||
|             causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream); |  | ||||||
|     }); |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| void causal_conv1d_update(const at::Tensor &x, |  | ||||||
|                      const at::Tensor &conv_state, |  | ||||||
|                      const at::Tensor &weight, |  | ||||||
|                      const std::optional<at::Tensor> &bias_, |  | ||||||
|                      bool silu_activation, |  | ||||||
|                      const std::optional<at::Tensor> &cache_seqlens_, |  | ||||||
|                      const std::optional<at::Tensor> &conv_state_indices_, |  | ||||||
|                      // used to identify padding entries if cache_indices provided |  | ||||||
|                      // in case of padding, the kernel will return early |  | ||||||
|                      int64_t pad_slot_id) { |  | ||||||
|     auto input_type = x.scalar_type(); |  | ||||||
|     auto weight_type = weight.scalar_type(); |  | ||||||
|     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); |  | ||||||
|     TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); |  | ||||||
|     TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); |  | ||||||
|     TORCH_CHECK(conv_state.scalar_type() == input_type); |  | ||||||
|  |  | ||||||
|     TORCH_CHECK(x.is_cuda()); |  | ||||||
|     TORCH_CHECK(conv_state.is_cuda()); |  | ||||||
|     TORCH_CHECK(weight.is_cuda()); |  | ||||||
|  |  | ||||||
|     const auto sizes = x.sizes(); |  | ||||||
|     const int batch_size = sizes[0]; |  | ||||||
|     const int dim = sizes[1]; |  | ||||||
|     const int seqlen = sizes[2]; |  | ||||||
|     const int width = weight.size(-1); |  | ||||||
|     const int conv_state_len = conv_state.size(2); |  | ||||||
|     TORCH_CHECK(conv_state_len >= width - 1); |  | ||||||
|  |  | ||||||
|     CHECK_SHAPE(x, batch_size, dim, seqlen); |  | ||||||
|     CHECK_SHAPE(weight, dim, width); |  | ||||||
|  |  | ||||||
|     TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); |  | ||||||
|  |  | ||||||
|     if (bias_.has_value()) { |  | ||||||
|         auto bias = bias_.value(); |  | ||||||
|         TORCH_CHECK(bias.scalar_type() == weight_type); |  | ||||||
|         TORCH_CHECK(bias.is_cuda()); |  | ||||||
|         TORCH_CHECK(bias.stride(-1) == 1); |  | ||||||
|         CHECK_SHAPE(bias, dim); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     at::Tensor out = x; |  | ||||||
|  |  | ||||||
|     ConvParamsBase params; |  | ||||||
|     set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, |  | ||||||
|                         bias_, |  | ||||||
|                         silu_activation, |  | ||||||
|                         pad_slot_id); |  | ||||||
|     params.conv_state_ptr = conv_state.data_ptr(); |  | ||||||
|     params.conv_state_len = conv_state_len; |  | ||||||
|     // All stride are in elements, not bytes. |  | ||||||
|     params.conv_state_batch_stride = conv_state.stride(0); |  | ||||||
|     params.conv_state_c_stride = conv_state.stride(1); |  | ||||||
|     params.conv_state_l_stride = conv_state.stride(2); |  | ||||||
|  |  | ||||||
|     if (cache_seqlens_.has_value()) { |  | ||||||
|         auto cache_seqlens = cache_seqlens_.value(); |  | ||||||
|         TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); |  | ||||||
|         TORCH_CHECK(cache_seqlens.is_cuda()); |  | ||||||
|         TORCH_CHECK(cache_seqlens.stride(-1) == 1); |  | ||||||
|         CHECK_SHAPE(cache_seqlens, batch_size); |  | ||||||
|         params.cache_seqlens = cache_seqlens.data_ptr<int32_t>(); |  | ||||||
|     } else { |  | ||||||
|         params.cache_seqlens = nullptr; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     if (conv_state_indices_.has_value()) { |  | ||||||
|         auto conv_state_indices = conv_state_indices_.value(); |  | ||||||
|         TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) |  | ||||||
|         TORCH_CHECK(conv_state_indices.is_cuda()); |  | ||||||
|         TORCH_CHECK(conv_state_indices.stride(0) == 1) |  | ||||||
|         CHECK_SHAPE(conv_state_indices, batch_size); |  | ||||||
|  |  | ||||||
|         int conv_state_entries = conv_state.size(0); |  | ||||||
|         CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); |  | ||||||
|  |  | ||||||
|         params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>(); |  | ||||||
|     } else { |  | ||||||
|         CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); |  | ||||||
|         params.conv_state_indices_ptr = nullptr; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // Otherwise the kernel will be launched from cuda:0 device |  | ||||||
|     // Cast to char to avoid compiler warning about narrowing |  | ||||||
|     at::cuda::CUDAGuard device_guard{(char)x.get_device()}; |  | ||||||
|     auto stream = at::cuda::getCurrentCUDAStream().stream(); |  | ||||||
|     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { |  | ||||||
|             causal_conv1d_update_cuda<input_t, weight_t>(params, stream); |  | ||||||
|     }); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_> |  | ||||||
| struct Causal_conv1d_fwd_kernel_traits { |  | ||||||
|     using input_t = input_t_; |  | ||||||
|     using weight_t = weight_t_; |  | ||||||
|     static constexpr int kNThreads = kNThreads_; |  | ||||||
|     static constexpr int kWidth = kWidth_; |  | ||||||
|     static constexpr int kNBytes = sizeof(input_t); |  | ||||||
|     static_assert(kNBytes == 2 || kNBytes == 4); |  | ||||||
|     static constexpr int kNElts = kNBytes == 4 ? 4 : 8; |  | ||||||
|     static_assert(kWidth <= kNElts); |  | ||||||
|     static constexpr bool kIsVecLoad = kIsVecLoad_; |  | ||||||
|     using vec_t = typename BytesToType<kNBytes * kNElts>::Type; |  | ||||||
|     using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>; |  | ||||||
|     using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>; |  | ||||||
|     using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>; |  | ||||||
|     using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>; |  | ||||||
|     static constexpr int kSmemIOSize = kIsVecLoad |  | ||||||
|         ? 0 |  | ||||||
|         : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); |  | ||||||
|     static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; |  | ||||||
|     static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<typename Ktraits> |  | ||||||
| __global__ __launch_bounds__(Ktraits::kNThreads) |  | ||||||
| void causal_conv1d_fwd_kernel(ConvParamsBase params) { |  | ||||||
|     constexpr int kWidth = Ktraits::kWidth; |  | ||||||
|     constexpr int kNThreads = Ktraits::kNThreads; |  | ||||||
|     constexpr int kNElts = Ktraits::kNElts; |  | ||||||
|     constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; |  | ||||||
|     using input_t = typename Ktraits::input_t; |  | ||||||
|     using vec_t = typename Ktraits::vec_t; |  | ||||||
|     using weight_t = typename Ktraits::weight_t; |  | ||||||
|  |  | ||||||
|     // Shared memory. |  | ||||||
|     extern __shared__ char smem_[]; |  | ||||||
|     auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_); |  | ||||||
|     auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_); |  | ||||||
|     auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_); |  | ||||||
|     auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_); |  | ||||||
|     vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize); |  | ||||||
|  |  | ||||||
|     const bool kVarlen = params.query_start_loc_ptr != nullptr; |  | ||||||
|     const int tidx = threadIdx.x; |  | ||||||
|     const int batch_id = blockIdx.x; |  | ||||||
|     const int channel_id = blockIdx.y; |  | ||||||
|     const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr; |  | ||||||
|     const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; |  | ||||||
|     const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; |  | ||||||
|  |  | ||||||
|     input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride |  | ||||||
|         + channel_id * params.x_c_stride; |  | ||||||
|     weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; |  | ||||||
|     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride |  | ||||||
|         + channel_id * params.out_c_stride; |  | ||||||
|     float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); |  | ||||||
|  |  | ||||||
|     bool has_initial_state = params.has_initial_state_ptr == nullptr ? false |  | ||||||
|         : reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id]; |  | ||||||
|  |  | ||||||
|     int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr |  | ||||||
|         : reinterpret_cast<int *>(params.cache_indices_ptr); |  | ||||||
|     int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; |  | ||||||
|     // cache_index == params.pad_slot_id is defined as padding, so we exit early |  | ||||||
|     if (cache_index == params.pad_slot_id){ |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|     input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr |  | ||||||
|         : reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; |  | ||||||
|  |  | ||||||
|     // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. |  | ||||||
|     if (tidx == 0) { |  | ||||||
|         input_t initial_state[kNElts] = {0}; |  | ||||||
|         if (has_initial_state) { |  | ||||||
|             #pragma unroll |  | ||||||
|             for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } |  | ||||||
|         } |  | ||||||
|         smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0]; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     float weight_vals[kWidth]; |  | ||||||
|     #pragma unroll |  | ||||||
|     for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } |  | ||||||
|  |  | ||||||
|     constexpr int kChunkSize = kNThreads * kNElts; |  | ||||||
|     const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; |  | ||||||
|     for (int chunk = 0; chunk < n_chunks; ++chunk) { |  | ||||||
|         input_t x_vals_load[2 * kNElts] = {0}; |  | ||||||
|         if constexpr(kIsVecLoad) { |  | ||||||
|             typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); |  | ||||||
|         } else { |  | ||||||
|             __syncthreads(); |  | ||||||
|             typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); |  | ||||||
|         } |  | ||||||
|         x += kChunkSize; |  | ||||||
|         __syncthreads(); |  | ||||||
|         // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read |  | ||||||
|         // the last elements of the previous chunk. |  | ||||||
|         if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; } |  | ||||||
|         __syncthreads(); |  | ||||||
|         reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; |  | ||||||
|         __syncthreads(); |  | ||||||
|         // Now thread kNThreads - 1 can write the last elements of the current chunk. |  | ||||||
|         if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; } |  | ||||||
|  |  | ||||||
|         float x_vals[2 * kNElts]; |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } |  | ||||||
|  |  | ||||||
|         float out_vals[kNElts]; |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int i = 0; i < kNElts; ++i) { |  | ||||||
|             out_vals[i] = bias_val; |  | ||||||
|             #pragma unroll |  | ||||||
|             for (int w = 0; w < kWidth; ++w) { |  | ||||||
|                 out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         if (params.silu_activation) { |  | ||||||
|             #pragma unroll |  | ||||||
|             for (int i = 0; i < kNElts; ++i) { |  | ||||||
|                 out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         input_t out_vals_store[kNElts]; |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } |  | ||||||
|         if constexpr(kIsVecLoad) { |  | ||||||
|             typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); |  | ||||||
|         } else { |  | ||||||
|             typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); |  | ||||||
|         } |  | ||||||
|         out += kChunkSize; |  | ||||||
|  |  | ||||||
|         int final_state_position =  ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); |  | ||||||
|         // In case the final state is split between the last "smem_exchange" |  | ||||||
|         // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), |  | ||||||
|         // (which occurs when `final_state_position` is a negative index) |  | ||||||
|         // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it |  | ||||||
|         if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ |  | ||||||
|             input_t vals_load[kNElts] = {0}; |  | ||||||
|             if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ |  | ||||||
|                 // chunk = n_chunks - 2, a segment of the final state sits in the last index |  | ||||||
|                 reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1]; |  | ||||||
|                 #pragma unroll |  | ||||||
|                 for (int w = 0; w < -final_state_position; ++w){ |  | ||||||
|                     conv_states[w] = vals_load[kNElts + final_state_position + w]; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             if ((chunk == n_chunks - 1) && tidx == 0){ |  | ||||||
|                 // chunk = n_chunks - 1, the second segment of the final state first positions |  | ||||||
|                 reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0]; |  | ||||||
|                 for (int w = -final_state_position; w < kWidth - 1; ++w){ |  | ||||||
|                     conv_states[w] = vals_load[w + final_state_position]; |  | ||||||
|                 } |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     // Final state is stored in the smem_exchange last token slot, |  | ||||||
|     // in case seqlen < kWidth, we would need to take the final state from the  |  | ||||||
|     // initial state which is stored in conv_states |  | ||||||
|     // in case seqlen > kWidth, we would need to load the last kWidth - 1 data |  | ||||||
|     // and load it into conv_state accordingly |  | ||||||
|     int last_thread =  ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; |  | ||||||
|     if (conv_states != nullptr && tidx == last_thread) {  |  | ||||||
|         input_t x_vals_load[kNElts * 2] = {0}; |  | ||||||
|         // in case we are on the first kWidth tokens |  | ||||||
|         if (last_thread == 0 && seqlen < kWidth){ |  | ||||||
|             // Need to take the initial state |  | ||||||
|             reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0]; |  | ||||||
|             const int offset = seqlen - (kWidth - 1); |  | ||||||
|             #pragma unroll |  | ||||||
|             for (int w = 0; w < kWidth - 1; ++w){ |  | ||||||
|                 // pad the existing state |  | ||||||
|                 if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } |  | ||||||
|                 else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } |  | ||||||
|             } |  | ||||||
|             #pragma unroll |  | ||||||
|             for (int w = 0; w < kWidth - 1; ++w){ |  | ||||||
|                 if (offset + w >= 0)  |  | ||||||
|                     conv_states[w] = x_vals_load[offset + w ]; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         else { |  | ||||||
|             // in case the final state is in between the threads data |  | ||||||
|             const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); |  | ||||||
|             if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){ |  | ||||||
|                 // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in an |  | ||||||
|                 // illegal access error on H100. |  | ||||||
|                 // Therefore, we access last_thread + 1, only if the final state data sits there |  | ||||||
|                 reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1]; |  | ||||||
|             } |  | ||||||
|             reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread]; |  | ||||||
|             #pragma unroll |  | ||||||
|             for (int w = 0; w < kWidth - 1; ++w){ |  | ||||||
|                 conv_states[w] = x_vals_load[offset + w ]; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|          |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| template<int kNThreads, int kWidth, typename input_t, typename weight_t> |  | ||||||
| void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { |  | ||||||
|     static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; |  | ||||||
|     const bool kVarlen = params.query_start_loc_ptr != nullptr; |  | ||||||
|     BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { |  | ||||||
|         using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>; |  | ||||||
|         constexpr int kSmemSize = Ktraits::kSmemSize; |  | ||||||
|         dim3 grid(params.batch, params.dim); |  | ||||||
|  |  | ||||||
|         auto kernel = &causal_conv1d_fwd_kernel<Ktraits>; |  | ||||||
|  |  | ||||||
|         if (kSmemSize >= 48 * 1024) { |  | ||||||
|             C10_CUDA_CHECK(cudaFuncSetAttribute( |  | ||||||
|                 (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); |  | ||||||
|             std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; |  | ||||||
|         } |  | ||||||
|         kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); |  | ||||||
|  |  | ||||||
|         C10_CUDA_KERNEL_LAUNCH_CHECK(); |  | ||||||
|     }); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template<typename input_t, typename weight_t> |  | ||||||
| void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { |  | ||||||
|     if (params.width == 2) { |  | ||||||
|         causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); |  | ||||||
|     } else if (params.width == 3) { |  | ||||||
|         causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); |  | ||||||
|     } else if (params.width == 4) { |  | ||||||
|         causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  |  | ||||||
| template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream); |  | ||||||
| template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); |  | ||||||
| template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
/// Compile-time configuration consumed by causal_conv1d_update_kernel:
/// re-exports the template arguments under stable member names and records
/// the input element size in bytes (only 2- and 4-byte inputs are accepted).
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    // Launch width and convolution width.
    static constexpr int kWidth = kWidth_;
    static constexpr int kNThreads = kNThreads_;

    // Element types of the input/output and of the weights.
    using weight_t = weight_t_;
    using input_t = input_t_;

    // Size of one input element in bytes.
    static constexpr int kNBytes = sizeof(input_t_);
    static_assert(kNBytes == 4 || kNBytes == 2);
};
|  |  | ||||||
| template<typename Ktraits, bool kIsCircularBuffer> |  | ||||||
| __global__ __launch_bounds__(Ktraits::kNThreads) |  | ||||||
| void causal_conv1d_update_kernel(ConvParamsBase params) { |  | ||||||
|     constexpr int kWidth = Ktraits::kWidth; |  | ||||||
|     constexpr int kNThreads = Ktraits::kNThreads; |  | ||||||
|     using input_t = typename Ktraits::input_t; |  | ||||||
|     using weight_t = typename Ktraits::weight_t; |  | ||||||
|  |  | ||||||
|     const int tidx = threadIdx.x; |  | ||||||
|     const int batch_id = blockIdx.x; |  | ||||||
|     const int channel_id = blockIdx.y * kNThreads + tidx; |  | ||||||
|     if (channel_id >= params.dim) return; |  | ||||||
|  |  | ||||||
|     input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride |  | ||||||
|         + channel_id * params.x_c_stride; |  | ||||||
|  |  | ||||||
|     // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor |  | ||||||
|     // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. |  | ||||||
|     const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr |  | ||||||
|         ? batch_id |  | ||||||
|         : params.conv_state_indices_ptr[batch_id]; |  | ||||||
|     // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early |  | ||||||
|     if (conv_state_batch_coord == params.pad_slot_id){ |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|     input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)  |  | ||||||
|         + conv_state_batch_coord * params.conv_state_batch_stride |  | ||||||
|         + channel_id * params.conv_state_c_stride; |  | ||||||
|  |  | ||||||
|     weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride; |  | ||||||
|     input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride |  | ||||||
|         + channel_id * params.out_c_stride; |  | ||||||
|     float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]); |  | ||||||
|  |  | ||||||
|     int state_len = params.conv_state_len; |  | ||||||
|     int advance_len = params.seqlen; |  | ||||||
|     int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; |  | ||||||
|     int update_idx = cache_seqlen - (kWidth - 1); |  | ||||||
|     update_idx = update_idx < 0 ? update_idx + state_len : update_idx; |  | ||||||
|  |  | ||||||
|     float weight_vals[kWidth] = {0}; |  | ||||||
|     #pragma unroll |  | ||||||
|     for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } |  | ||||||
|  |  | ||||||
|     float x_vals[kWidth] = {0}; |  | ||||||
|     if constexpr (!kIsCircularBuffer) { |  | ||||||
|         #pragma unroll 2 |  | ||||||
|         for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { |  | ||||||
|             conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; |  | ||||||
|         } |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int i = 0; i < kWidth - 1; ++i) { |  | ||||||
|             input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; |  | ||||||
|             if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { |  | ||||||
|                 conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; |  | ||||||
|             } |  | ||||||
|             x_vals[i] = float(state_val); |  | ||||||
|         } |  | ||||||
|     } else { |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { |  | ||||||
|             input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; |  | ||||||
|             x_vals[i] = float(state_val); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     #pragma unroll 2 |  | ||||||
|     for (int i = 0; i < params.seqlen; ++i) { |  | ||||||
|         input_t x_val = x[i * params.x_l_stride]; |  | ||||||
|         if constexpr (!kIsCircularBuffer) { |  | ||||||
|             if (i < advance_len && state_len - advance_len + i >= 0) { |  | ||||||
|                 conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             conv_state[update_idx * params.conv_state_l_stride] = x_val; |  | ||||||
|             ++update_idx; |  | ||||||
|             update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; |  | ||||||
|         } |  | ||||||
|         x_vals[kWidth - 1] = float(x_val); |  | ||||||
|         float out_val = bias_val; |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } |  | ||||||
|         if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } |  | ||||||
|         out[i * params.out_l_stride] = input_t(out_val); |  | ||||||
|         // Shift the input buffer by 1 |  | ||||||
|         #pragma unroll |  | ||||||
|         for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template<int kNThreads, int kWidth, typename input_t, typename weight_t> |  | ||||||
| void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { |  | ||||||
|     using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>; |  | ||||||
|     dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); |  | ||||||
|     auto kernel = params.cache_seqlens == nullptr |  | ||||||
|         ? &causal_conv1d_update_kernel<Ktraits, false> |  | ||||||
|         : &causal_conv1d_update_kernel<Ktraits, true>; |  | ||||||
|     kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params); |  | ||||||
|     C10_CUDA_KERNEL_LAUNCH_CHECK(); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template<typename input_t, typename weight_t> |  | ||||||
| void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { |  | ||||||
|     if (params.width == 2) { |  | ||||||
|         causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); |  | ||||||
|     } else if (params.width == 3) { |  | ||||||
|         causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); |  | ||||||
|     } else if (params.width == 4) { |  | ||||||
|         causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream); |  | ||||||
| template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream); |  | ||||||
| template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream); |  | ||||||
| @ -1,159 +0,0 @@ | |||||||
| /****************************************************************************** |  | ||||||
|  * Copyright (c) 2024, Tri Dao. |  | ||||||
|  ******************************************************************************/ |  | ||||||
| // clang-format off |  | ||||||
| // adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h |  | ||||||
| #pragma once |  | ||||||
|  |  | ||||||
| #include <cuda_bf16.h> |  | ||||||
| #include <cuda_fp16.h> |  | ||||||
| //////////////////////////////////////////////////////////////////////////////////////////////////// |  | ||||||
|  |  | ||||||
/// Argument bundle handed (by value) to the causal conv1d kernels.
/// Data pointers are type-erased; the kernels cast them to their element
/// type and offset them with the strides stored here.
struct ConvParamsBase {
    // Type of every stride field.
    using index_t = uint32_t;

    // Sizes.
    int batch;
    int dim;
    int seqlen;
    int width;
    // Slot value that marks a padded entry.
    int64_t pad_slot_id;
    // Whether the SiLU activation is applied to the outputs.
    bool silu_activation;

    // Strides of x, weight and out.
    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    index_t weight_c_stride;
    index_t weight_width_stride;
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;

    // Length and strides of the cached state used by the update path.
    int conv_state_len;
    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Tensor data shared by all kernels.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;
    void *__restrict__ query_start_loc_ptr;
    void *__restrict__ has_initial_state_ptr;
    void *__restrict__ cache_indices_ptr;
    int32_t *__restrict__ cache_seqlens;

    // Continuous batching: lets the mamba state of the current batch live in
    // a tensor that is not contiguous along the batch axis.
    int32_t *__restrict__ conv_state_indices_ptr;

    void *__restrict__ seq_idx_ptr;

    // Deliberately not __restrict__: initial_states may alias final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;

    void *  conv_states_ptr;
    index_t conv_states_batch_stride;
    index_t conv_states_l_stride;
    index_t conv_states_c_stride;
};
|  |  | ||||||
|  |  | ||||||
| #ifndef USE_ROCM |  | ||||||
|     #include <cuda_bf16.h> |  | ||||||
|  |  | ||||||
|     template<typename T> |  | ||||||
|     __device__ inline T shuffle_xor(T val, int offset) { |  | ||||||
|         return __shfl_xor_sync(uint32_t(-1), val, offset); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     constexpr size_t custom_max(std::initializer_list<size_t> ilist)  |  | ||||||
|     { |  | ||||||
|         return std::max(ilist); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     template<typename T> |  | ||||||
|     constexpr T constexpr_min(T a, T b) { |  | ||||||
|         return std::min(a, b); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
| #else |  | ||||||
|     #include <hip/hip_bf16.h> |  | ||||||
|  |  | ||||||
|     template<typename T> |  | ||||||
|     __device__ inline T shuffle_xor(T val, int offset) { |  | ||||||
|         return __shfl_xor(val, offset); |  | ||||||
|     } |  | ||||||
|     constexpr size_t custom_max(std::initializer_list<size_t> ilist)  |  | ||||||
|     { |  | ||||||
|         return *std::max_element(ilist.begin(), ilist.end()); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     template<typename T> |  | ||||||
|     constexpr T constexpr_min(T a, T b) { |  | ||||||
|         return a < b ? a : b; |  | ||||||
|     } |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| //////////////////////////////////////////////////////////////////////////////////////////////////// |  | ||||||
|  |  | ||||||
| template<int BYTES> struct BytesToType {}; |  | ||||||
|  |  | ||||||
| template<> struct BytesToType<16> { |  | ||||||
|     using Type = uint4; |  | ||||||
|     static_assert(sizeof(Type) == 16); |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<> struct BytesToType<8> { |  | ||||||
|     using Type = uint64_t; |  | ||||||
|     static_assert(sizeof(Type) == 8); |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<> struct BytesToType<4> { |  | ||||||
|     using Type = uint32_t; |  | ||||||
|     static_assert(sizeof(Type) == 4); |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<> struct BytesToType<2> { |  | ||||||
|     using Type = uint16_t; |  | ||||||
|     static_assert(sizeof(Type) == 2); |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<> struct BytesToType<1> { |  | ||||||
|     using Type = uint8_t; |  | ||||||
|     static_assert(sizeof(Type) == 1); |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| //////////////////////////////////////////////////////////////////////////////////////////////////// |  | ||||||
|  |  | ||||||
| template<typename T> |  | ||||||
| struct SumOp { |  | ||||||
| __device__ inline T operator()(T const & x, T const & y) { return x + y; } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<int THREADS> |  | ||||||
| struct Allreduce { |  | ||||||
|     static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); |  | ||||||
|     template<typename T, typename Operator> |  | ||||||
|     static __device__ inline T run(T x, Operator &op) { |  | ||||||
|         constexpr int OFFSET = THREADS / 2; |  | ||||||
|         x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); |  | ||||||
|         return Allreduce<OFFSET>::run(x, op); |  | ||||||
|     } |  | ||||||
| }; |  | ||||||
|  |  | ||||||
| template<> |  | ||||||
| struct Allreduce<2> { |  | ||||||
| template<typename T, typename Operator> |  | ||||||
| static __device__ inline T run(T x, Operator &op) { |  | ||||||
|     x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); |  | ||||||
|     return x; |  | ||||||
| } |  | ||||||
| }; |  | ||||||
| @ -1,28 +0,0 @@ | |||||||
| // Inspired by |  | ||||||
| // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h |  | ||||||
| // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h |  | ||||||
| // clang-format off |  | ||||||
| // adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h |  | ||||||
|  |  | ||||||
| #pragma once |  | ||||||
|  |  | ||||||
/// Turns a runtime boolean into a compile-time constant by running the given
/// callable in one of two scopes, each declaring CONST_NAME as a
/// `static constexpr bool` (true or false).
///
/// @param COND       - boolean expression evaluated once at runtime
/// @param CONST_NAME - name of the constexpr bool visible inside the callable
/// @param ...        - callable (typically a lambda) invoked with no arguments;
///                     its result is the value of the whole expression
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)                \
    [&] {                                                 \
        if (!(COND)) {                                    \
            static constexpr bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
        static constexpr bool CONST_NAME = true;          \
        return __VA_ARGS__();                             \
    }()
| @ -7,7 +7,11 @@ | |||||||
|  |  | ||||||
| #include <c10/util/BFloat16.h> | #include <c10/util/BFloat16.h> | ||||||
| #include <c10/util/Half.h> | #include <c10/util/Half.h> | ||||||
| #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK | #ifdef USE_ROCM | ||||||
|  |     #include <c10/hip/HIPException.h>  // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK | ||||||
|  | #else | ||||||
|  |     #include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK | ||||||
|  | #endif | ||||||
|  |  | ||||||
| #ifndef USE_ROCM | #ifndef USE_ROCM | ||||||
|     #include <cub/block/block_load.cuh> |     #include <cub/block/block_load.cuh> | ||||||
| @ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { | |||||||
|     // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size |     // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size | ||||||
|     constexpr bool kIsVariableB = true; |     constexpr bool kIsVariableB = true; | ||||||
|     constexpr bool kIsVariableC = true; |     constexpr bool kIsVariableC = true; | ||||||
|     constexpr bool kHasZ = true; |  | ||||||
|     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { |     BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { | ||||||
|         BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { |         BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { | ||||||
|             using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ,  kVarlen, input_t, weight_t>; |             BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { | ||||||
|             constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); |                 using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ,  kVarlen, input_t, weight_t>; | ||||||
|             dim3 grid(params.batch, params.dim / kNRows); |                 constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); | ||||||
|             auto kernel = &selective_scan_fwd_kernel<Ktraits>; |                 dim3 grid(params.batch, params.dim / kNRows); | ||||||
|             if (kSmemSize >= 48 * 1024) { |                 auto kernel = &selective_scan_fwd_kernel<Ktraits>; | ||||||
|                 C10_CUDA_CHECK(cudaFuncSetAttribute( |                 if (kSmemSize >= 48 * 1024) { | ||||||
|                     (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); | #ifdef USE_ROCM | ||||||
|             } |                     C10_HIP_CHECK(hipFuncSetAttribute( | ||||||
|             kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); |                         reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); | ||||||
|             C10_CUDA_KERNEL_LAUNCH_CHECK(); | #else | ||||||
|  |                     C10_CUDA_CHECK(cudaFuncSetAttribute( | ||||||
|  |                         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); | ||||||
|  | #endif | ||||||
|  |                 } | ||||||
|  |                 kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params); | ||||||
|  |                 C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||
|  |             }); | ||||||
|         }); |         }); | ||||||
|     }); |     }); | ||||||
| } | } | ||||||
| @ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, | |||||||
|  |  | ||||||
|     at::Tensor z, out_z; |     at::Tensor z, out_z; | ||||||
|     const bool has_z = z_.has_value(); |     const bool has_z = z_.has_value(); | ||||||
|     TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") |     if (has_z) { | ||||||
|     z = z_.value(); |         z = z_.value(); | ||||||
|     TORCH_CHECK(z.scalar_type() == input_type); |         TORCH_CHECK(z.scalar_type() == input_type); | ||||||
|     TORCH_CHECK(z.is_cuda()); |         TORCH_CHECK(z.is_cuda()); | ||||||
|     TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); |         TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); | ||||||
|     if (varlen){ |         if (varlen){ | ||||||
|         CHECK_SHAPE(z, dim, seqlen); |             CHECK_SHAPE(z, dim, seqlen); | ||||||
|     } else { |         } else { | ||||||
|         CHECK_SHAPE(z, batch_size, dim, seqlen); |             CHECK_SHAPE(z, batch_size, dim, seqlen); | ||||||
|  |         } | ||||||
|  |          | ||||||
|  |         out_z = z; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     out_z = z; |  | ||||||
|  |  | ||||||
|     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout |     // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout | ||||||
|     at::Tensor out = delta; |     at::Tensor out = delta; | ||||||
|     TORCH_CHECK(ssm_states.scalar_type() == input_type); |     TORCH_CHECK(ssm_states.scalar_type() == input_type); | ||||||
| @ -647,12 +658,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, | |||||||
|                        ); |                        ); | ||||||
|  |  | ||||||
|      |      | ||||||
|     // Otherwise the kernel will be launched from cuda:0 device |     const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); | ||||||
|     // Cast to char to avoid compiler warning about narrowing |  | ||||||
|     at::cuda::CUDAGuard device_guard{(char)u.get_device()}; |  | ||||||
|     auto stream = at::cuda::getCurrentCUDAStream().stream(); |     auto stream = at::cuda::getCurrentCUDAStream().stream(); | ||||||
|     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { |     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { | ||||||
|         selective_scan_fwd_cuda<input_t, weight_t>(params, stream); |         selective_scan_fwd_cuda<input_t, weight_t>(params, stream); | ||||||
|     }); |     }); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
| @ -1255,8 +1255,6 @@ __global__ void Marlin( | |||||||
|     if constexpr (has_zp && !is_zp_float) { |     if constexpr (has_zp && !is_zp_float) { | ||||||
|       if (is_new_zp) { |       if (is_new_zp) { | ||||||
|         if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; |         if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; | ||||||
|         FragB frag_zp_0; |  | ||||||
|         FragB frag_zp_1; |  | ||||||
|         int zp_quant_0, zp_quant_1; |         int zp_quant_0, zp_quant_1; | ||||||
|  |  | ||||||
|         if constexpr (w_type.size_bits() == 4) { |         if constexpr (w_type.size_bits() == 4) { | ||||||
|  | |||||||
| @ -13,232 +13,45 @@ | |||||||
| namespace vllm { | namespace vllm { | ||||||
| namespace moe { | namespace moe { | ||||||
|  |  | ||||||
| namespace { |  | ||||||
| __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, |  | ||||||
|                                          int32_t col) { |  | ||||||
|   // don't worry about overflow because num_experts is relatively small |  | ||||||
|   return row * total_col + col; |  | ||||||
| } |  | ||||||
| }  // namespace |  | ||||||
|  |  | ||||||
| template <typename scalar_t, typename token_cnts_t> |  | ||||||
| __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, |  | ||||||
|                                             int32_t* sorted_token_ids, |  | ||||||
|                                             int32_t* expert_ids, |  | ||||||
|                                             int32_t* total_tokens_post_pad, |  | ||||||
|                                             int32_t num_experts, |  | ||||||
|                                             int32_t block_size, size_t numel) { |  | ||||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); |  | ||||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; |  | ||||||
|  |  | ||||||
|   extern __shared__ int32_t shared_mem[]; |  | ||||||
|   int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1) |  | ||||||
|   token_cnts_t* tokens_cnts = |  | ||||||
|       (token_cnts_t*)(shared_mem + num_experts + |  | ||||||
|                       1);  // 2d tensor with shape (blockDim.x + 1, num_experts) |  | ||||||
|  |  | ||||||
|   for (int i = 0; i < num_experts; ++i) { |  | ||||||
|     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * In the first step we compute token_cnts[thread_index + 1][expert_index], |  | ||||||
|    * which counts how many tokens in the token shard of thread_index are |  | ||||||
|    * assigned to expert expert_index. |  | ||||||
|    */ |  | ||||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { |  | ||||||
|     ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   __syncthreads(); |  | ||||||
|  |  | ||||||
|   // For each expert we accumulate the token counts from the different threads. |  | ||||||
|   if (threadIdx.x < num_experts) { |  | ||||||
|     tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; |  | ||||||
|     for (int i = 1; i <= blockDim.x; ++i) { |  | ||||||
|       tokens_cnts[index(num_experts, i, threadIdx.x)] += |  | ||||||
|           tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   __syncthreads(); |  | ||||||
|  |  | ||||||
|   // We accumulate the token counts of all experts in thread 0. |  | ||||||
|   if (threadIdx.x == 0) { |  | ||||||
|     cumsum[0] = 0; |  | ||||||
|     for (int i = 1; i <= num_experts; ++i) { |  | ||||||
|       cumsum[i] = cumsum[i - 1] + |  | ||||||
|                   CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], |  | ||||||
|                           block_size) * |  | ||||||
|                       block_size; |  | ||||||
|     } |  | ||||||
|     *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]); |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   __syncthreads(); |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * For each expert, each thread processes the tokens of the corresponding |  | ||||||
|    * blocks and stores the corresponding expert_id for each block. |  | ||||||
|    */ |  | ||||||
|   if (threadIdx.x < num_experts) { |  | ||||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; |  | ||||||
|          i += block_size) { |  | ||||||
|       expert_ids[i / block_size] = threadIdx.x; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * Each thread processes a token shard, calculating the index of each token |  | ||||||
|    * after sorting by expert number. Given the example topk_ids = |  | ||||||
|    * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, |  | ||||||
|    * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a |  | ||||||
|    * padding value(preset in python). |  | ||||||
|    */ |  | ||||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { |  | ||||||
|     int32_t expert_id = topk_ids[i]; |  | ||||||
|     /** The cumsum[expert_id] stores the starting index of the tokens that the |  | ||||||
|      * expert with expert_id needs to process, and |  | ||||||
|      * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens |  | ||||||
|      * processed by the expert with expert_id within the current thread's token |  | ||||||
|      * shard. |  | ||||||
|      */ |  | ||||||
|     int32_t rank_post_pad = |  | ||||||
|         tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + |  | ||||||
|         cumsum[expert_id]; |  | ||||||
|     sorted_token_ids[rank_post_pad] = i; |  | ||||||
|     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // TODO(simon): this is temporarily adapted from |  | ||||||
| // https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 |  | ||||||
| // we did this to unblock Deepseek V3 but there should be a better |  | ||||||
| // implementation to manage shared memory. |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| __global__ void moe_align_block_size_global_mem_kernel( | __global__ void moe_align_block_size_kernel( | ||||||
|     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, |     const scalar_t* __restrict__ topk_ids, | ||||||
|     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, |     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, | ||||||
|     int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { |     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, | ||||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); |     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, | ||||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; |     size_t numel, int32_t* __restrict__ cumsum) { | ||||||
|  |   extern __shared__ int32_t shared_counts[]; | ||||||
|  |  | ||||||
|   for (int i = 0; i < num_experts; ++i) { |   const int warp_id = threadIdx.x / WARP_SIZE; | ||||||
|     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * In the first step we compute token_cnts[thread_index + 1][expert_index], |  | ||||||
|    * which counts how many tokens in the token shard of thread_index are |  | ||||||
|    * assigned to expert expert_index. |  | ||||||
|    */ |  | ||||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { |  | ||||||
|     ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   __syncthreads(); |  | ||||||
|  |  | ||||||
|   // For each expert we accumulate the token counts from the different threads. |  | ||||||
|   if (threadIdx.x < num_experts) { |  | ||||||
|     tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; |  | ||||||
|     for (int i = 1; i <= blockDim.x; ++i) { |  | ||||||
|       tokens_cnts[index(num_experts, i, threadIdx.x)] += |  | ||||||
|           tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   __syncthreads(); |  | ||||||
|  |  | ||||||
|   // We accumulate the token counts of all experts in thread 0. |  | ||||||
|   if (threadIdx.x == 0) { |  | ||||||
|     cumsum[0] = 0; |  | ||||||
|     for (int i = 1; i <= num_experts; ++i) { |  | ||||||
|       cumsum[i] = cumsum[i - 1] + |  | ||||||
|                   CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], |  | ||||||
|                           block_size) * |  | ||||||
|                       block_size; |  | ||||||
|     } |  | ||||||
|     *total_tokens_post_pad = cumsum[num_experts]; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   __syncthreads(); |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * For each expert, each thread processes the tokens of the corresponding |  | ||||||
|    * blocks and stores the corresponding expert_id for each block. |  | ||||||
|    */ |  | ||||||
|   if (threadIdx.x < num_experts) { |  | ||||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; |  | ||||||
|          i += block_size) { |  | ||||||
|       expert_ids[i / block_size] = threadIdx.x; |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   /** |  | ||||||
|    * Each thread processes a token shard, calculating the index of each token |  | ||||||
|    * after sorting by expert number. Given the example topk_ids = |  | ||||||
|    * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, |  | ||||||
|    * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a |  | ||||||
|    * padding value(preset in python). |  | ||||||
|    */ |  | ||||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { |  | ||||||
|     int32_t expert_id = topk_ids[i]; |  | ||||||
|     /** The cumsum[expert_id] stores the starting index of the tokens that the |  | ||||||
|      * expert with expert_id needs to process, and |  | ||||||
|      * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens |  | ||||||
|      * processed by the expert with expert_id within the current thread's token |  | ||||||
|      * shard. |  | ||||||
|      */ |  | ||||||
|     int32_t rank_post_pad = |  | ||||||
|         tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + |  | ||||||
|         cumsum[expert_id]; |  | ||||||
|     sorted_token_ids[rank_post_pad] = i; |  | ||||||
|     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // taken from |  | ||||||
| // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 |  | ||||||
| template <typename scalar_t> |  | ||||||
| __global__ void sgl_moe_align_block_size_kernel( |  | ||||||
|     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, |  | ||||||
|     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, |  | ||||||
|     int32_t block_size, size_t numel, int32_t* cumsum) { |  | ||||||
|   __shared__ int32_t shared_counts[32][8]; |  | ||||||
|  |  | ||||||
|   const int warp_id = threadIdx.x / 32; |  | ||||||
|   const int experts_per_warp = 8; |  | ||||||
|   const int my_expert_start = warp_id * experts_per_warp; |   const int my_expert_start = warp_id * experts_per_warp; | ||||||
|  |  | ||||||
|   // Initialize shared_counts for this warp's experts |  | ||||||
|   for (int i = 0; i < experts_per_warp; ++i) { |   for (int i = 0; i < experts_per_warp; ++i) { | ||||||
|     if (my_expert_start + i < num_experts) { |     if (my_expert_start + i < padded_num_experts) { | ||||||
|       shared_counts[warp_id][i] = 0; |       shared_counts[warp_id * experts_per_warp + i] = 0; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   __syncthreads(); |   __syncthreads(); | ||||||
|  |  | ||||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); |   const size_t tid = threadIdx.x; | ||||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; |   const size_t stride = blockDim.x; | ||||||
|  |  | ||||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { |   for (size_t i = tid; i < numel; i += stride) { | ||||||
|     int expert_id = topk_ids[i]; |     int expert_id = topk_ids[i]; | ||||||
|     int warp_idx = expert_id / experts_per_warp; |     int warp_idx = expert_id / experts_per_warp; | ||||||
|     int expert_offset = expert_id % experts_per_warp; |     int expert_offset = expert_id % experts_per_warp; | ||||||
|     atomicAdd(&shared_counts[warp_idx][expert_offset], 1); |     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); | ||||||
|   } |   } | ||||||
|  |  | ||||||
|   __syncthreads(); |   __syncthreads(); | ||||||
|  |  | ||||||
|   // Single thread computes cumulative sum and total tokens |  | ||||||
|   if (threadIdx.x == 0) { |   if (threadIdx.x == 0) { | ||||||
|     cumsum[0] = 0; |     cumsum[0] = 0; | ||||||
|     for (int i = 1; i <= num_experts; ++i) { |     for (int i = 1; i <= num_experts; ++i) { | ||||||
|       int expert_count = 0; |       int expert_count = 0; | ||||||
|       int warp_idx = (i - 1) / experts_per_warp; |       int warp_idx = (i - 1) / experts_per_warp; | ||||||
|       int expert_offset = (i - 1) % experts_per_warp; |       int expert_offset = (i - 1) % experts_per_warp; | ||||||
|       expert_count = shared_counts[warp_idx][expert_offset]; |       expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; | ||||||
|  |  | ||||||
|       cumsum[i] = |       cumsum[i] = | ||||||
|           cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; |           cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; | ||||||
| @ -248,7 +61,6 @@ __global__ void sgl_moe_align_block_size_kernel( | |||||||
|  |  | ||||||
|   __syncthreads(); |   __syncthreads(); | ||||||
|  |  | ||||||
|   // Assign expert IDs to blocks |  | ||||||
|   if (threadIdx.x < num_experts) { |   if (threadIdx.x < num_experts) { | ||||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; |     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||||
|          i += block_size) { |          i += block_size) { | ||||||
| @ -257,13 +69,11 @@ __global__ void sgl_moe_align_block_size_kernel( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| // taken from |  | ||||||
| // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 |  | ||||||
| template <typename scalar_t> | template <typename scalar_t> | ||||||
| __global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, | __global__ void count_and_sort_expert_tokens_kernel( | ||||||
|                                           int32_t* sorted_token_ids, |     const scalar_t* __restrict__ topk_ids, | ||||||
|                                           int32_t* cumsum_buffer, |     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, | ||||||
|                                           size_t numel) { |     size_t numel) { | ||||||
|   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; |   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|   const size_t stride = blockDim.x * gridDim.x; |   const size_t stride = blockDim.x * gridDim.x; | ||||||
|  |  | ||||||
| @ -290,132 +100,138 @@ __global__ void moe_sum_kernel( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template <typename scalar_t> | ||||||
|  | __global__ void moe_align_block_size_small_batch_expert_kernel( | ||||||
|  |     const scalar_t* __restrict__ topk_ids, | ||||||
|  |     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, | ||||||
|  |     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, | ||||||
|  |     int32_t block_size, size_t numel) { | ||||||
|  |   const size_t tid = threadIdx.x; | ||||||
|  |   const size_t stride = blockDim.x; | ||||||
|  |  | ||||||
|  |   extern __shared__ int32_t shared_mem[]; | ||||||
|  |   int32_t* cumsum = shared_mem; | ||||||
|  |   int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); | ||||||
|  |  | ||||||
|  |   for (int i = 0; i < num_experts; ++i) { | ||||||
|  |     tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   for (size_t i = tid; i < numel; i += stride) { | ||||||
|  |     ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   __syncthreads(); | ||||||
|  |  | ||||||
|  |   if (threadIdx.x < num_experts) { | ||||||
|  |     tokens_cnts[threadIdx.x] = 0; | ||||||
|  |     for (int i = 1; i <= blockDim.x; ++i) { | ||||||
|  |       tokens_cnts[i * num_experts + threadIdx.x] += | ||||||
|  |           tokens_cnts[(i - 1) * num_experts + threadIdx.x]; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   __syncthreads(); | ||||||
|  |  | ||||||
|  |   if (threadIdx.x == 0) { | ||||||
|  |     cumsum[0] = 0; | ||||||
|  |     for (int i = 1; i <= num_experts; ++i) { | ||||||
|  |       cumsum[i] = | ||||||
|  |           cumsum[i - 1] + | ||||||
|  |           CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * | ||||||
|  |               block_size; | ||||||
|  |     } | ||||||
|  |     *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   __syncthreads(); | ||||||
|  |  | ||||||
|  |   if (threadIdx.x < num_experts) { | ||||||
|  |     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||||
|  |          i += block_size) { | ||||||
|  |       expert_ids[i / block_size] = threadIdx.x; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   for (size_t i = tid; i < numel; i += stride) { | ||||||
|  |     int32_t expert_id = topk_ids[i]; | ||||||
|  |     int32_t rank_post_pad = | ||||||
|  |         tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; | ||||||
|  |     sorted_token_ids[rank_post_pad] = i; | ||||||
|  |     ++tokens_cnts[threadIdx.x * num_experts + expert_id]; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| }  // namespace moe | }  // namespace moe | ||||||
| }  // namespace vllm | }  // namespace vllm | ||||||
|  |  | ||||||
|  | // taken from | ||||||
|  | // https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc | ||||||
| void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | ||||||
|                           int64_t block_size, torch::Tensor sorted_token_ids, |                           int64_t block_size, torch::Tensor sorted_token_ids, | ||||||
|                           torch::Tensor experts_ids, |                           torch::Tensor experts_ids, | ||||||
|                           torch::Tensor num_tokens_post_pad) { |                           torch::Tensor num_tokens_post_pad) { | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |  | ||||||
|   int device_max_shared_mem; |   int64_t padded_num_experts = | ||||||
|   auto dev = topk_ids.get_device(); |       ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; | ||||||
|   cudaDeviceGetAttribute(&device_max_shared_mem, |   int experts_per_warp = WARP_SIZE; | ||||||
|                          cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); |   int threads = 1024; | ||||||
|  |   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; | ||||||
|   const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); |  | ||||||
|   const int32_t shared_mem_i32 = |  | ||||||
|       ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); |  | ||||||
|   const int32_t shared_mem_i16 = |  | ||||||
|       ((num_thread + 1) * num_experts) * sizeof(uint16_t) + |  | ||||||
|       (num_experts + 1) * sizeof(int32_t); |  | ||||||
|  |  | ||||||
|   bool use_global_memory = false; |  | ||||||
|   bool use_i16 = false;  // Use uint16_t for shared memory token counts |  | ||||||
|   if (shared_mem_i32 < device_max_shared_mem) { |  | ||||||
|     // Do nothing in this case. We're all set to use int32_t token counts |  | ||||||
|   } else if (shared_mem_i16 < device_max_shared_mem && |  | ||||||
|              topk_ids.numel() <= 65535) { |  | ||||||
|     // when nelements of topk_ids is smaller than 65535 (max value of uint16), |  | ||||||
|     // element value of token_cnts would also smaller than 65535, |  | ||||||
|     // so we can use uint16 as dtype of token_cnts |  | ||||||
|     use_i16 = true; |  | ||||||
|   } else { |  | ||||||
|     use_global_memory = true; |  | ||||||
|   } |  | ||||||
|  |  | ||||||
|   if (use_global_memory) { |  | ||||||
|     VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( |  | ||||||
|         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { |  | ||||||
|           // calc needed amount of shared mem for `tokens_cnts` and `cumsum` |  | ||||||
|           // tensors |  | ||||||
|           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); |  | ||||||
|  |  | ||||||
|           auto options_int = torch::TensorOptions() |  | ||||||
|                                  .dtype(torch::kInt) |  | ||||||
|                                  .device(topk_ids.device()); |  | ||||||
|           torch::Tensor token_cnts_buffer = |  | ||||||
|               torch::empty({(num_experts + 1) * num_experts}, options_int); |  | ||||||
|           torch::Tensor cumsum_buffer = |  | ||||||
|               torch::empty({num_experts + 1}, options_int); |  | ||||||
|  |  | ||||||
|           auto kernel = |  | ||||||
|               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>; |  | ||||||
|           kernel<<<1, num_thread, 0, stream>>>( |  | ||||||
|               topk_ids.data_ptr<scalar_t>(), |  | ||||||
|               sorted_token_ids.data_ptr<int32_t>(), |  | ||||||
|               experts_ids.data_ptr<int32_t>(), |  | ||||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, |  | ||||||
|               topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(), |  | ||||||
|               cumsum_buffer.data_ptr<int32_t>()); |  | ||||||
|         }); |  | ||||||
|   } else if (use_i16) { |  | ||||||
|     VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( |  | ||||||
|         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { |  | ||||||
|           // set dynamic shared mem |  | ||||||
|           auto kernel = |  | ||||||
|               vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>; |  | ||||||
|           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( |  | ||||||
|               (void*)kernel, shared_mem_i16)); |  | ||||||
|           kernel<<<1, num_thread, shared_mem_i16, stream>>>( |  | ||||||
|               topk_ids.data_ptr<scalar_t>(), |  | ||||||
|               sorted_token_ids.data_ptr<int32_t>(), |  | ||||||
|               experts_ids.data_ptr<int32_t>(), |  | ||||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, |  | ||||||
|               topk_ids.numel()); |  | ||||||
|         }); |  | ||||||
|   } else { |  | ||||||
|     VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( |  | ||||||
|         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { |  | ||||||
|           auto kernel = |  | ||||||
|               vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>; |  | ||||||
|           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( |  | ||||||
|               (void*)kernel, shared_mem_i32)); |  | ||||||
|           kernel<<<1, num_thread, shared_mem_i32, stream>>>( |  | ||||||
|               topk_ids.data_ptr<scalar_t>(), |  | ||||||
|               sorted_token_ids.data_ptr<int32_t>(), |  | ||||||
|               experts_ids.data_ptr<int32_t>(), |  | ||||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, |  | ||||||
|               topk_ids.numel()); |  | ||||||
|         }); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, |  | ||||||
|                               int64_t block_size, |  | ||||||
|                               torch::Tensor sorted_token_ids, |  | ||||||
|                               torch::Tensor experts_ids, |  | ||||||
|                               torch::Tensor num_tokens_post_pad) { |  | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |  | ||||||
|   TORCH_CHECK(num_experts == 256, |  | ||||||
|               "sgl_moe_align_block_size kernel only supports deepseek v3."); |  | ||||||
|  |  | ||||||
|   VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( |   VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( | ||||||
|       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { |       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||||
|         // calc needed amount of shared mem for `cumsum` tensors |         // calc needed amount of shared mem for `cumsum` tensors | ||||||
|         auto options_int = |         auto options_int = | ||||||
|             torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); |             torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); | ||||||
|         torch::Tensor cumsum_buffer = |         torch::Tensor cumsum_buffer = | ||||||
|             torch::zeros({num_experts + 1}, options_int); |             torch::zeros({num_experts + 1}, options_int); | ||||||
|  |         bool small_batch_expert_mode = | ||||||
|  |             (topk_ids.numel() < 1024) && (num_experts <= 64); | ||||||
|  |  | ||||||
|         auto align_kernel = |         if (small_batch_expert_mode) { | ||||||
|             vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>; |           const int32_t threads = max((int32_t)num_experts, WARP_SIZE); | ||||||
|         align_kernel<<<1, 1024, 0, stream>>>( |           const int32_t shared_mem_size = | ||||||
|             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), |               ((threads + 1) * num_experts + (num_experts + 1)) * | ||||||
|             experts_ids.data_ptr<int32_t>(), |               sizeof(int32_t); | ||||||
|             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, |  | ||||||
|             topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>()); |  | ||||||
|  |  | ||||||
|         const int block_threads = 256; |           auto small_batch_expert_kernel = | ||||||
|         const int num_blocks = |               vllm::moe::moe_align_block_size_small_batch_expert_kernel< | ||||||
|             (topk_ids.numel() + block_threads - 1) / block_threads; |                   scalar_t>; | ||||||
|         const int max_blocks = 65535; |           small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( | ||||||
|         const int actual_blocks = std::min(num_blocks, max_blocks); |               topk_ids.data_ptr<scalar_t>(), | ||||||
|         auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>; |               sorted_token_ids.data_ptr<int32_t>(), | ||||||
|         sort_kernel<<<actual_blocks, block_threads, 0, stream>>>( |               experts_ids.data_ptr<int32_t>(), | ||||||
|             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), |               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||||
|             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel()); |               topk_ids.numel()); | ||||||
|  |         } else { | ||||||
|  |           auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>; | ||||||
|  |  | ||||||
|  |           size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); | ||||||
|  |           size_t shared_mem_size = | ||||||
|  |               num_warps * experts_per_warp * sizeof(int32_t); | ||||||
|  |  | ||||||
|  |           align_kernel<<<1, threads, shared_mem_size, stream>>>( | ||||||
|  |               topk_ids.data_ptr<scalar_t>(), | ||||||
|  |               sorted_token_ids.data_ptr<int32_t>(), | ||||||
|  |               experts_ids.data_ptr<int32_t>(), | ||||||
|  |               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, | ||||||
|  |               padded_num_experts, experts_per_warp, block_size, | ||||||
|  |               topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>()); | ||||||
|  |  | ||||||
|  |           const int block_threads = std::min(256, (int)threads); | ||||||
|  |           const int num_blocks = | ||||||
|  |               (topk_ids.numel() + block_threads - 1) / block_threads; | ||||||
|  |           const int max_blocks = 65535; | ||||||
|  |           const int actual_blocks = std::min(num_blocks, max_blocks); | ||||||
|  |  | ||||||
|  |           auto sort_kernel = | ||||||
|  |               vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>; | ||||||
|  |           sort_kernel<<<actual_blocks, block_threads, 0, stream>>>( | ||||||
|  |               topk_ids.data_ptr<scalar_t>(), | ||||||
|  |               sorted_token_ids.data_ptr<int32_t>(), | ||||||
|  |               cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel()); | ||||||
|  |         } | ||||||
|       }); |       }); | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -423,7 +239,7 @@ void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size] | |||||||
|              torch::Tensor& output)  // [num_tokens, hidden_size] |              torch::Tensor& output)  // [num_tokens, hidden_size] | ||||||
| { | { | ||||||
|   const int hidden_size = input.size(-1); |   const int hidden_size = input.size(-1); | ||||||
|   const int num_tokens = output.numel() / hidden_size; |   const auto num_tokens = output.numel() / hidden_size; | ||||||
|   const int topk = input.size(1); |   const int topk = input.size(1); | ||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|  | |||||||
| @ -12,12 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | |||||||
|                           int64_t block_size, torch::Tensor sorted_token_ids, |                           int64_t block_size, torch::Tensor sorted_token_ids, | ||||||
|                           torch::Tensor experts_ids, |                           torch::Tensor experts_ids, | ||||||
|                           torch::Tensor num_tokens_post_pad); |                           torch::Tensor num_tokens_post_pad); | ||||||
|  |  | ||||||
| void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, |  | ||||||
|                               int64_t block_size, |  | ||||||
|                               torch::Tensor sorted_token_ids, |  | ||||||
|                               torch::Tensor experts_ids, |  | ||||||
|                               torch::Tensor num_tokens_post_pad); |  | ||||||
| #ifndef USE_ROCM | #ifndef USE_ROCM | ||||||
| torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, | torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, | ||||||
|                              torch::Tensor b_qweight, torch::Tensor b_scales, |                              torch::Tensor b_qweight, torch::Tensor b_scales, | ||||||
|  | |||||||
| @ -12,7 +12,7 @@ void moe_permute( | |||||||
|     const torch::Tensor& input,                      // [n_token, hidden] |     const torch::Tensor& input,                      // [n_token, hidden] | ||||||
|     const torch::Tensor& topk_weights,               //[n_token, topk] |     const torch::Tensor& topk_weights,               //[n_token, topk] | ||||||
|     torch::Tensor& topk_ids,                         // [n_token, topk] |     torch::Tensor& topk_ids,                         // [n_token, topk] | ||||||
|     const torch::Tensor& token_expert_indicies,      // [n_token, topk] |     const torch::Tensor& token_expert_indices,       // [n_token, topk] | ||||||
|     const std::optional<torch::Tensor>& expert_map,  // [n_expert] |     const std::optional<torch::Tensor>& expert_map,  // [n_expert] | ||||||
|     int64_t n_expert, int64_t n_local_expert, int64_t topk, |     int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||||
|     const std::optional<int64_t>& align_block_size, |     const std::optional<int64_t>& align_block_size, | ||||||
| @ -27,15 +27,15 @@ void moe_permute( | |||||||
|               "expert_first_token_offset must be int64"); |               "expert_first_token_offset must be int64"); | ||||||
|   TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, |   TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, | ||||||
|               "topk_ids must be int32"); |               "topk_ids must be int32"); | ||||||
|   TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, |   TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, | ||||||
|               "token_expert_indicies must be int32"); |               "token_expert_indices must be int32"); | ||||||
|   TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, |   TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, | ||||||
|               "src_row_id2dst_row_id_map must be int32"); |               "src_row_id2dst_row_id_map must be int32"); | ||||||
|   TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, |   TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, | ||||||
|               "expert_first_token_offset shape != n_local_expert+1") |               "expert_first_token_offset shape != n_local_expert+1") | ||||||
|   TORCH_CHECK( |   TORCH_CHECK( | ||||||
|       src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(), |       src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), | ||||||
|       "token_expert_indicies shape must be same as src_row_id2dst_row_id_map"); |       "token_expert_indices shape must be same as src_row_id2dst_row_id_map"); | ||||||
|   auto n_token = input.sizes()[0]; |   auto n_token = input.sizes()[0]; | ||||||
|   auto n_hidden = input.sizes()[1]; |   auto n_hidden = input.sizes()[1]; | ||||||
|   auto align_block_size_value = |   auto align_block_size_value = | ||||||
| @ -71,7 +71,7 @@ void moe_permute( | |||||||
|                              expert_map_ptr, n_expert, stream); |                              expert_map_ptr, n_expert, stream); | ||||||
|   } |   } | ||||||
|   // expert sort topk expert id and scan expert id get expert_first_token_offset |   // expert sort topk expert id and scan expert id get expert_first_token_offset | ||||||
|   sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies), |   sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices), | ||||||
|                     get_ptr<int>(permuted_experts_id), |                     get_ptr<int>(permuted_experts_id), | ||||||
|                     get_ptr<int>(dst_row_id2src_row_id_map), |                     get_ptr<int>(dst_row_id2src_row_id_map), | ||||||
|                     get_ptr<int64_t>(expert_first_token_offset), n_token, |                     get_ptr<int64_t>(expert_first_token_offset), n_token, | ||||||
| @ -190,7 +190,7 @@ void shuffle_rows(const torch::Tensor& input_tensor, | |||||||
|  |  | ||||||
| void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, | void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, | ||||||
|                  torch::Tensor& topk_ids, |                  torch::Tensor& topk_ids, | ||||||
|                  const torch::Tensor& token_expert_indicies, |                  const torch::Tensor& token_expert_indices, | ||||||
|                  const std::optional<torch::Tensor>& expert_map, |                  const std::optional<torch::Tensor>& expert_map, | ||||||
|                  int64_t n_expert, int64_t n_local_expert, int64_t topk, |                  int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||||
|                  const std::optional<int64_t>& align_block_size, |                  const std::optional<int64_t>& align_block_size, | ||||||
| @ -203,7 +203,7 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, | |||||||
|  |  | ||||||
| void moe_unpermute(const torch::Tensor& input, | void moe_unpermute(const torch::Tensor& input, | ||||||
|                    const torch::Tensor& topk_weights, torch::Tensor& topk_ids, |                    const torch::Tensor& topk_weights, torch::Tensor& topk_ids, | ||||||
|                    const torch::Tensor& token_expert_indicies, |                    const torch::Tensor& token_expert_indices, | ||||||
|                    const std::optional<torch::Tensor>& expert_map, |                    const std::optional<torch::Tensor>& expert_map, | ||||||
|                    int64_t n_expert, int64_t n_local_expert, int64_t topk, |                    int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||||
|                    const std::optional<int64_t>& align_block_size, |                    const std::optional<int64_t>& align_block_size, | ||||||
|  | |||||||
| @ -20,7 +20,6 @@ __global__ void expandInputRowsKernel( | |||||||
|   int expert_id = sorted_experts[expanded_dest_row]; |   int expert_id = sorted_experts[expanded_dest_row]; | ||||||
|  |  | ||||||
|   extern __shared__ int64_t smem_expert_first_token_offset[]; |   extern __shared__ int64_t smem_expert_first_token_offset[]; | ||||||
|   int64_t align_expanded_row_accumulate = 0; |  | ||||||
|   if constexpr (ALIGN_BLOCK_SIZE) { |   if constexpr (ALIGN_BLOCK_SIZE) { | ||||||
|     // load g2s |     // load g2s | ||||||
|     for (int idx = threadIdx.x; idx < num_local_experts + 1; |     for (int idx = threadIdx.x; idx < num_local_experts + 1; | ||||||
| @ -63,7 +62,6 @@ __global__ void expandInputRowsKernel( | |||||||
|     using DataElem = cutlass::Array<T, ELEM_PER_THREAD>; |     using DataElem = cutlass::Array<T, ELEM_PER_THREAD>; | ||||||
|  |  | ||||||
|     // Duplicate and permute rows |     // Duplicate and permute rows | ||||||
|     int64_t const source_k_rank = expanded_source_row / num_rows; |  | ||||||
|     int64_t const source_row = expanded_source_row % num_rows; |     int64_t const source_row = expanded_source_row % num_rows; | ||||||
|  |  | ||||||
|     auto const* source_row_ptr = |     auto const* source_row_ptr = | ||||||
| @ -160,7 +158,6 @@ __global__ void finalizeMoeRoutingKernel( | |||||||
|        elem_index += stride) { |        elem_index += stride) { | ||||||
|     ComputeElem thread_output; |     ComputeElem thread_output; | ||||||
|     thread_output.fill(0); |     thread_output.fill(0); | ||||||
|     float row_rescale{0.f}; |  | ||||||
|     for (int k_idx = 0; k_idx < k; ++k_idx) { |     for (int k_idx = 0; k_idx < k; ++k_idx) { | ||||||
|       int64_t const expanded_original_row = original_row + k_idx * num_rows; |       int64_t const expanded_original_row = original_row + k_idx * num_rows; | ||||||
|       int64_t const expanded_permuted_row = |       int64_t const expanded_permuted_row = | ||||||
| @ -177,8 +174,6 @@ __global__ void finalizeMoeRoutingKernel( | |||||||
|       auto const* expanded_permuted_rows_row_ptr = |       auto const* expanded_permuted_rows_row_ptr = | ||||||
|           expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; |           expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; | ||||||
|  |  | ||||||
|       int64_t const expert_idx = expert_for_source_row[k_offset]; |  | ||||||
|  |  | ||||||
|       ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>( |       ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>( | ||||||
|           expanded_permuted_rows_row_ptr[elem_index]); |           expanded_permuted_rows_row_ptr[elem_index]); | ||||||
|       thread_output = thread_output + row_scale * (expert_result); |       thread_output = thread_output + row_scale * (expert_result); | ||||||
|  | |||||||
| @ -425,7 +425,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f | |||||||
|  |  | ||||||
| #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                       \ | #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                       \ | ||||||
|     topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(         \ |     topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(         \ | ||||||
|         gating_output, nullptr, topk_weights, topk_indicies,            \ |         gating_output, nullptr, topk_weights, topk_indices,            \ | ||||||
|         token_expert_indices, num_tokens, topk, 0, num_experts,         \ |         token_expert_indices, num_tokens, topk, 0, num_experts,         \ | ||||||
|         stream); |         stream); | ||||||
|  |  | ||||||
| @ -433,7 +433,7 @@ template <typename IndType> | |||||||
| void topkGatingSoftmaxKernelLauncher( | void topkGatingSoftmaxKernelLauncher( | ||||||
|     const float* gating_output, |     const float* gating_output, | ||||||
|     float* topk_weights, |     float* topk_weights, | ||||||
|     IndType* topk_indicies, |     IndType* topk_indices, | ||||||
|     int* token_expert_indices, |     int* token_expert_indices, | ||||||
|     float* softmax_workspace, |     float* softmax_workspace, | ||||||
|     const int num_tokens, |     const int num_tokens, | ||||||
| @ -476,7 +476,7 @@ void topkGatingSoftmaxKernelLauncher( | |||||||
|             moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>( |             moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>( | ||||||
|                 gating_output, nullptr, softmax_workspace, num_experts); |                 gating_output, nullptr, softmax_workspace, num_experts); | ||||||
|             moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>( |             moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>( | ||||||
|                 softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices, |                 softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, | ||||||
|                 num_experts, topk, 0, num_experts); |                 num_experts, topk, 0, num_experts); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| @ -492,7 +492,7 @@ void topk_softmax( | |||||||
|     torch::Tensor& gating_output)               // [num_tokens, num_experts] |     torch::Tensor& gating_output)               // [num_tokens, num_experts] | ||||||
| { | { | ||||||
|     const int num_experts = gating_output.size(-1); |     const int num_experts = gating_output.size(-1); | ||||||
|     const int num_tokens = gating_output.numel() / num_experts; |     const auto num_tokens = gating_output.numel() / num_experts; | ||||||
|     const int topk = topk_weights.size(-1); |     const int topk = topk_weights.size(-1); | ||||||
|  |  | ||||||
|     const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); |     const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); | ||||||
|  | |||||||
| @ -22,15 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { | |||||||
|       "                     Tensor! num_tokens_post_pad) -> ()"); |       "                     Tensor! num_tokens_post_pad) -> ()"); | ||||||
|   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); |   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); | ||||||
|  |  | ||||||
|   // temporarily adapted from |  | ||||||
|   // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a |  | ||||||
|   m.def( |  | ||||||
|       "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," |  | ||||||
|       "                         int block_size, Tensor! sorted_token_ids," |  | ||||||
|       "                         Tensor! experts_ids," |  | ||||||
|       "                         Tensor! num_tokens_post_pad) -> ()"); |  | ||||||
|   m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); |  | ||||||
|  |  | ||||||
| #ifndef USE_ROCM | #ifndef USE_ROCM | ||||||
|   m.def( |   m.def( | ||||||
|       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " |       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " | ||||||
| @ -66,7 +57,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { | |||||||
|  |  | ||||||
|   m.def( |   m.def( | ||||||
|       "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," |       "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," | ||||||
|       "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," |       "Tensor token_expert_indices, Tensor? expert_map, int n_expert," | ||||||
|       "int n_local_expert," |       "int n_local_expert," | ||||||
|       "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " |       "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " | ||||||
|       "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " |       "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " | ||||||
|  | |||||||
							
								
								
									
										27
									
								
								csrc/ops.h
									
									
									
									
									
								
							
							
						
						
									
										27
									
								
								csrc/ops.h
									
									
									
									
									
								
							| @ -326,22 +326,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, | |||||||
|                         const std::optional<torch::Tensor>& has_initial_state, |                         const std::optional<torch::Tensor>& has_initial_state, | ||||||
|                         const torch::Tensor& ssm_states, int64_t pad_slot_id); |                         const torch::Tensor& ssm_states, int64_t pad_slot_id); | ||||||
|  |  | ||||||
| void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, |  | ||||||
|                           const at::Tensor& weight, |  | ||||||
|                           const std::optional<at::Tensor>& bias_, |  | ||||||
|                           bool silu_activation, |  | ||||||
|                           const std::optional<at::Tensor>& cache_seqlens_, |  | ||||||
|                           const std::optional<at::Tensor>& conv_state_indices_, |  | ||||||
|                           int64_t pad_slot_id); |  | ||||||
|  |  | ||||||
| void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, |  | ||||||
|                        const std::optional<at::Tensor>& bias_, |  | ||||||
|                        const std::optional<at::Tensor>& conv_states, |  | ||||||
|                        const std::optional<at::Tensor>& query_start_loc, |  | ||||||
|                        const std::optional<at::Tensor>& cache_indices, |  | ||||||
|                        const std::optional<at::Tensor>& has_initial_state, |  | ||||||
|                        bool silu_activation, int64_t pad_slot_id); |  | ||||||
|  |  | ||||||
| using fptr_t = int64_t; | using fptr_t = int64_t; | ||||||
| fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs, | fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs, | ||||||
|                       torch::Tensor& rank_data, int64_t rank, |                       torch::Tensor& rank_data, int64_t rank, | ||||||
| @ -360,3 +344,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle( | |||||||
|     int64_t size); |     int64_t size); | ||||||
| int64_t open_mem_handle(torch::Tensor& mem_handle); | int64_t open_mem_handle(torch::Tensor& mem_handle); | ||||||
| void free_shared_buffer(int64_t buffer); | void free_shared_buffer(int64_t buffer); | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  | fptr_t init_custom_qr(int64_t rank, int64_t world_size, | ||||||
|  |                       std::optional<int64_t> qr_max_size = std::nullopt); | ||||||
|  | void qr_destroy(fptr_t _fa); | ||||||
|  | torch::Tensor qr_get_handle(fptr_t _fa); | ||||||
|  | void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles); | ||||||
|  | void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, | ||||||
|  |                    int64_t quant_level, bool cast_bf2half = false); | ||||||
|  | int64_t qr_max_size(); | ||||||
|  | #endif | ||||||
| @ -274,7 +274,6 @@ void advance_step_flashinfer( | |||||||
|   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); |   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); | ||||||
|   cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); |   cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); | ||||||
|  |  | ||||||
|   [[maybe_unused]] int block_tables_stride = block_tables.stride(0); |  | ||||||
|   TORCH_CHECK((blocks * threads > num_queries), |   TORCH_CHECK((blocks * threads > num_queries), | ||||||
|               "multi-step: not enough threads to map to num_queries = ", |               "multi-step: not enough threads to map to num_queries = ", | ||||||
|               num_queries, " block_tables.stride(0) = ", block_tables.stride(0), |               num_queries, " block_tables.stride(0) = ", block_tables.stride(0), | ||||||
|  | |||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	