mirror of https://github.com/vllm-project/vllm.git
synced 2025-11-04 17:34:34 +08:00

Compare commits

551 Commits
| SHA1 | Author | Date | |
|---|---|---|---|
| 8f89d72090 | |||
| 99dac099ab | |||
| c4bd03c7c5 | |||
| dcbf4286af | |||
| 00e6a2dc53 | |||
| 2e02311a1b | |||
| 89ec06c33b | |||
| 9fde251bf0 | |||
| 4c2ffb28ff | |||
| 246598a6b1 | |||
| 8bab4959be | |||
| 3c4cebf751 | |||
| d8f31f2f8b | |||
| 640052b069 | |||
| 351d5e7b82 | |||
| a008629807 | |||
| 76477a93b7 | |||
| 77c87beb06 | |||
| 114332b88e | |||
| cb77ad836f | |||
| 856c990041 | |||
| c5602f0baa | |||
| f7f9c5f97b | |||
| 2c0d933594 | |||
| 774d1035e4 | |||
| 6b29d6fe70 | |||
| 0bfa1c4f13 | |||
| c81da5f56d | |||
| 68bc81703e | |||
| 5884c2b454 | |||
| 45f92c00cf | |||
| 5467ac3196 | |||
| 5d7e3d0176 | |||
| 0373e1837e | |||
| c09dade2a2 | |||
| 8ea5e44a43 | |||
| 9fb900f90c | |||
| c96fc06747 | |||
| b3376e5c76 | |||
| e69ded7d1c | |||
| 767c727a81 | |||
| 6840a71610 | |||
| 7a9cb294ae | |||
| ca3ea51bde | |||
| dc49fb892c | |||
| 18a277b52d | |||
| 8d75fe48ca | |||
| 388596c914 | |||
| baa15a9ec3 | |||
| 15063741e3 | |||
| ccdc490dda | |||
| a31cab7556 | |||
| 828da0d44e | |||
| abe855d637 | |||
| 4efff036f0 | |||
| 89c920785f | |||
| 7b0a0dfb22 | |||
| 3a6ae1d33c | |||
| 8f1729b829 | |||
| 6a7c7711a2 | |||
| 0f83ddd4d7 | |||
| 065aff6c16 | |||
| 3d33e372a1 | |||
| faf71bcd4b | |||
| f270a39537 | |||
| 51a08e7d8f | |||
| eb8fcd2666 | |||
| 5563a4dea8 | |||
| ccd4f129e8 | |||
| 02cc3b51a7 | |||
| d5b1eb081e | |||
| f0a500545f | |||
| c65146e75e | |||
| 41ca62cf03 | |||
| 974fc9b845 | |||
| fee4dcc33a | |||
| 650a4cc55e | |||
| 9ca62d8668 | |||
| 45c35f0d58 | |||
| 9ba093b4f4 | |||
| 27208be66e | |||
| 87d5abef75 | |||
| ec784b2526 | |||
| a58f24e590 | |||
| f42a006b15 | |||
| 3a434b07ed | |||
| bd0e7802e0 | |||
| 06b2550cbb | |||
| f775a07e30 | |||
| 4f0d17c05c | |||
| 10c38e3e46 | |||
| cafb8e06c5 | |||
| cbb2f59cc8 | |||
| 0ab278ca31 | |||
| 7a64d24aad | |||
| dfbe60dc62 | |||
| a66cf40b20 | |||
| f790ad3c50 | |||
| ed59a7ed23 | |||
| 044793d8df | |||
| c2d6d2f960 | |||
| 8279078e21 | |||
| b9c0605a8e | |||
| 37464a0f74 | |||
| c354072828 | |||
| f081c3ce4b | |||
| 260d119e86 | |||
| a360ff80bb | |||
| 1197e02141 | |||
| 657579113f | |||
| e9899fb7a4 | |||
| a377f0bd5e | |||
| e9d3aa04f6 | |||
| a22dea54d3 | |||
| 533c217792 | |||
| 6d21fa1cad | |||
| b35be5403f | |||
| 45a1a69b98 | |||
| 87a658c812 | |||
| 429d89720e | |||
| a9bcc7afb2 | |||
| d79d9eaaff | |||
| f758505c73 | |||
| d910816c73 | |||
| 87d41c849d | |||
| e07aff9e52 | |||
| 5bf185a1c4 | |||
| 4fbcb0f27e | |||
| 7c3604fb68 | |||
| b1c255630d | |||
| eb6c50cdc2 | |||
| eecd864388 | |||
| ae495c74ea | |||
| 4238bc82f2 | |||
| 594392d27a | |||
| 18c1f16d86 | |||
| 5bd3c65072 | |||
| 616e600e0b | |||
| dfba529b40 | |||
| 5ae5ed1e60 | |||
| 290f4ada2b | |||
| dd8de11f0a | |||
| 9ba415588a | |||
| d4f3985907 | |||
| 890aa93d27 | |||
| fbdb7b3ee2 | |||
| 1102bef219 | |||
| f17a1a8f96 | |||
| d5a1697772 | |||
| 325c119961 | |||
| 8e192ff967 | |||
| e64fde4b01 | |||
| 919770957f | |||
| 6a50f4cafa | |||
| e3470f8753 | |||
| a1242324c9 | |||
| 5eda2ea02a | |||
| 2ba80bed27 | |||
| 6066253296 | |||
| ee3eea0a1b | |||
| a36de682d4 | |||
| eb6d3c264d | |||
| 97b030005c | |||
| a3a73ab069 | |||
| 8674f9880e | |||
| c74c913bfb | |||
| 5f6d10c14c | |||
| 9b9a10d6cb | |||
| 99eff67ba9 | |||
| 14772eeb8e | |||
| 757b62c495 | |||
| e941f88584 | |||
| f12c3b5b3d | |||
| d130b573a0 | |||
| 65ae8c2c8f | |||
| c3af44722c | |||
| 1937e29848 | |||
| f0eecee610 | |||
| 943e72ca56 | |||
| 546a97ef69 | |||
| da5a0b539d | |||
| 6287537a0c | |||
| b57e6c5949 | |||
| 27ce85476e | |||
| f68470e803 | |||
| 2e9a2227ec | |||
| c0724fc915 | |||
| 86b45ae065 | |||
| c5711ef985 | |||
| 48d5985a08 | |||
| 33e0823de5 | |||
| 26148120b3 | |||
| 0150a10630 | |||
| 8e7fb5d43a | |||
| 9a31a817a8 | |||
| 2060e93659 | |||
| 8435b207af | |||
| 10fa9eea21 | |||
| e08188081b | |||
| b5853f9963 | |||
| f09edd8a25 | |||
| 6979ade384 | |||
| 9216b9cc38 | |||
| 5e0391c040 | |||
| dbc0754ddf | |||
| 99caa49106 | |||
| 5c342570d7 | |||
| 973617ae02 | |||
| 30e754390c | |||
| 52f8107cf2 | |||
| fc0d9dfc3a | |||
| 361c461a12 | |||
| a5675d348b | |||
| e9cdd2b1e2 | |||
| 65bf2ac165 | |||
| 8a7cc254a0 | |||
| 29bc01bf3b | |||
| 676a99982f | |||
| dc72402b57 | |||
| ccb63a8245 | |||
| c579b750a0 | |||
| 4bfa7e7f75 | |||
| ac1fbf7fd2 | |||
| 33d3914b1e | |||
| 1356df53bd | |||
| ce532ff45c | |||
| 8bc68e198c | |||
| 0fca3cdcf2 | |||
| e7c46b9527 | |||
| 350f9e107f | |||
| 702bee461f | |||
| a7be4d0072 | |||
| a709e87a4f | |||
| 6eaccb7353 | |||
| e254497b66 | |||
| 4e12131089 | |||
| fcc2994be6 | |||
| 2e7796f2cf | |||
| 706588a77d | |||
| 6a0f617210 | |||
| dac6a3f6ed | |||
| 64b77dfd7e | |||
| 51d4094fda | |||
| e965d46184 | |||
| 208b71bcc1 | |||
| c833101740 | |||
| 379da6dcb5 | |||
| ebce310b74 | |||
| be0c5180ac | |||
| cea64430f6 | |||
| a3c124570a | |||
| ff5abcd746 | |||
| 0ee535b294 | |||
| 190bc838e1 | |||
| f12b20decc | |||
| 16bc0a098f | |||
| e288df0632 | |||
| 8b9241be3a | |||
| f942efb5a3 | |||
| 89579a201f | |||
| 230c4b38c1 | |||
| 20cfcdec99 | |||
| ad932a221d | |||
| 5510cf0e8a | |||
| 0f9a6e3d22 | |||
| f6a593093a | |||
| d7740ea4dc | |||
| cc466a3290 | |||
| 8344f7742b | |||
| 469f85c782 | |||
| 10760da800 | |||
| 478aed5827 | |||
| 63575bc2e1 | |||
| a98187cf72 | |||
| bd99d22629 | |||
| 19cb4716ee | |||
| e186d37cb1 | |||
| 323f27b904 | |||
| 0650e5935b | |||
| c7f2cf2b7f | |||
| 8d8357c8ed | |||
| 4302987069 | |||
| 021b1a2ab7 | |||
| 2a052011ca | |||
| 36fb68f947 | |||
| bc8ad68455 | |||
| 344bf7cd2d | |||
| ab50275111 | |||
| 43c413ec57 | |||
| f8e7adda21 | |||
| 7e65477e5e | |||
| 3521ba4f25 | |||
| 2d7bce9cd5 | |||
| ce3f1eedf8 | |||
| 808632d3b4 | |||
| 344a5d0c33 | |||
| 0f8a91401c | |||
| 9b5c9f9484 | |||
| 32881f3f31 | |||
| 5b8a7c1cb0 | |||
| 1ff0c73a79 | |||
| 5ad60b0cbd | |||
| fb087af52e | |||
| 7038e8b803 | |||
| 2a85f93007 | |||
| cf8cac8c70 | |||
| 5e401bce17 | |||
| 0d62fe58db | |||
| b8afa8b95a | |||
| 826b82a260 | |||
| c9d852d601 | |||
| 6ef09b08f8 | |||
| 3a922c1e7e | |||
| c47ba4aaa9 | |||
| 24bb4fe432 | |||
| a657bfc48a | |||
| 24750f4cad | |||
| b38e42fbca | |||
| 8b798eec75 | |||
| 69909126a7 | |||
| e491c7e053 | |||
| 4dc8026d86 | |||
| a88bb9b032 | |||
| 6f1df80436 | |||
| d6f4bd7cdd | |||
| c3845d82dc | |||
| a822eb3413 | |||
| f458112e8a | |||
| 2e240c69a9 | |||
| ee37328da0 | |||
| 6ad58f42c5 | |||
| dd1a50a8bc | |||
| 715c2d854d | |||
| a494140433 | |||
| 111815d482 | |||
| b31a1fb63c | |||
| 4bb53e2dde | |||
| 26f2fb5113 | |||
| fa32207842 | |||
| d627a3d837 | |||
| f4f921b7f1 | |||
| ac5ccf0156 | |||
| 73c8d677e5 | |||
| df29793dc7 | |||
| 03dd7d52bf | |||
| bf480c5302 | |||
| 9c7306ac11 | |||
| 4ea1f9678d | |||
| ba4be44c32 | |||
| d6e520e170 | |||
| 81661da7b2 | |||
| dfea173148 | |||
| 7134303cbb | |||
| 3da24c2df7 | |||
| eefeb16464 | |||
| 18d23f642a | |||
| 87f545ba6f | |||
| 8947bc3c15 | |||
| 12628d3c78 | |||
| 258a2c58d0 | |||
| aba47be3fe | |||
| a62aaf1df5 | |||
| 603ad84815 | |||
| a88081bf76 | |||
| 2f30e7c72f | |||
| a74dee9b62 | |||
| cf29b7eda4 | |||
| efffb63f58 | |||
| 15e7c675b0 | |||
| b6dcb4d442 | |||
| b5b4a398a7 | |||
| f4bc4de1b1 | |||
| bd7a8eef25 | |||
| 7ee82bef1e | |||
| fbf152d976 | |||
| 479d69fad0 | |||
| 96e90fdeb3 | |||
| a395a638c2 | |||
| 2768884ac4 | |||
| aae08249ac | |||
| 7923dcad12 | |||
| 3cd9b5bb2d | |||
| 468d761b32 | |||
| e4bf860a54 | |||
| 91f50a6fe2 | |||
| 79a268c4ab | |||
| eace8bf0b9 | |||
| 1e8f4252aa | |||
| 2b7949c1c2 | |||
| 62b5166bd4 | |||
| d86285a4a4 | |||
| d87f39e9a9 | |||
| d3c8180ac4 | |||
| 62b8aebc6f | |||
| 050f285ff6 | |||
| 8f2ea22bde | |||
| 0ae11f78ab | |||
| 34128a697e | |||
| c1b4e4157c | |||
| ceaf4ed003 | |||
| ad8d696a99 | |||
| 3d925165f2 | |||
| 1543680691 | |||
| 077f0a2e8a | |||
| e73ed0f1c6 | |||
| 296cdf8ac7 | |||
| 747b1a7147 | |||
| 95e5b087cf | |||
| a37d815b83 | |||
| 7f2593b164 | |||
| fe7d648fe5 | |||
| cc74b2b232 | |||
| 91528575ec | |||
| a22cdea371 | |||
| 682789d402 | |||
| 138485a82d | |||
| bc9df1571b | |||
| 15b86408a8 | |||
| 7be4f5628f | |||
| 8f20fc04bf | |||
| 221d93ecbf | |||
| d17c8477f1 | |||
| a134ef6f5e | |||
| 8a7a3e4436 | |||
| 8f9c28fd40 | |||
| cd2f63fb36 | |||
| 87fa80c91f | |||
| e1bb2fd52d | |||
| 705578ae14 | |||
| e8cc7967ff | |||
| 53b018edcb | |||
| 66ded03067 | |||
| 6dc1fc9cfe | |||
| 533d2a1f39 | |||
| a53222544c | |||
| fe3b5bbc23 | |||
| 8438e0569e | |||
| 11d652bd4f | |||
| d150e4f89f | |||
| e95cd87959 | |||
| 69e1d2fb69 | |||
| 05434764cd | |||
| 4e7ee664e2 | |||
| 37e84a403d | |||
| 4695397dcf | |||
| d619ae2d19 | |||
| eb46fbfda2 | |||
| 0003e9154b | |||
| e11e200736 | |||
| 8db1bf32f8 | |||
| aceb17cf2d | |||
| 563c54f760 | |||
| 2cd6b4f362 | |||
| 711a000255 | |||
| 989ae2538d | |||
| 0a430b4ae2 | |||
| ec8e3c695f | |||
| 98afde19fc | |||
| 5c2e66e487 | |||
| 546e721168 | |||
| b8aacac31a | |||
| d04973ad54 | |||
| fbb9d9eef4 | |||
| 09473ee41c | |||
| d4ec9ffb95 | |||
| 96b6a6d790 | |||
| 36729bac13 | |||
| 7fd3949a0b | |||
| 1096717ae9 | |||
| c2b4a1bce9 | |||
| e46a60aa4c | |||
| 1e96c3341a | |||
| 95e7d4a97c | |||
| 559eb852f8 | |||
| a10d3056da | |||
| 8afca50889 | |||
| 08ccee1e83 | |||
| c1dc547129 | |||
| f3d0bf7589 | |||
| e9da5a40c6 | |||
| e42df7227d | |||
| caada5e50a | |||
| 67b4221a61 | |||
| 63e7176f26 | |||
| 934d3662f7 | |||
| 92cd2e2f21 | |||
| e4c4072c94 | |||
| e35397468f | |||
| 8b317c6dd0 | |||
| bd3c144e0b | |||
| 0258b7a94b | |||
| b3104b2a10 | |||
| c2e00af523 | |||
| c013d32c75 | |||
| 11dd6ebb89 | |||
| 6c0b04515f | |||
| e23a43aef8 | |||
| e7c7067b45 | |||
| 6d592eb430 | |||
| d036198e23 | |||
| 59a6abf3c9 | |||
| bc0c0192d1 | |||
| f46864d68d | |||
| b4543c8f6b | |||
| 0ce0539d47 | |||
| 2f19283549 | |||
| 95baec828f | |||
| e4be7d70bb | |||
| 54951ac4bf | |||
| 18de883489 | |||
| 1d7c940d74 | |||
| cfaf49a167 | |||
| 9edec652e2 | |||
| e0dd4d3589 | |||
| e5043a3e75 | |||
| d03d64fd2e | |||
| 78107fa091 | |||
| c391e4b68e | |||
| 9117f892f0 | |||
| db2a6a41e2 | |||
| ca81ff5196 | |||
| b7782002e1 | |||
| 819a309c0f | |||
| aabe8f40f2 | |||
| 498eb5cfa3 | |||
| 537ee25f43 | |||
| 294f8f6665 | |||
| b95047f2da | |||
| 2ff767b513 | |||
| 3dcb3e8b98 | |||
| c64cf38673 | |||
| 76b889bf1d | |||
| c9b506dad4 | |||
| 5757d90e26 | |||
| a3c226e7eb | |||
| b321d4881b | |||
| ad6eca408b | |||
| 205b94942e | |||
| 3bec41f41a | |||
| 0739b1947f | |||
| 77a6572aa5 | |||
| 0e3f06fe9c | |||
| eb69d68804 | |||
| 7d4e1b85e7 | |||
| 93deb0b38f | |||
| ccb58b23e6 | |||
| 49782fcb76 | |||
| f03cc667a0 | |||
| 563c1d7ec5 | |||
| 9c82a1bec3 | |||
| b6d103542c | |||

36  .buildkite/check-wheel-size.py  Normal file

@@ -0,0 +1,36 @@
import os
import zipfile

MAX_SIZE_MB = 200


def print_top_10_largest_files(zip_file):
    with zipfile.ZipFile(zip_file, 'r') as z:
        file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
        file_sizes.sort(key=lambda x: x[1], reverse=True)
        for f, size in file_sizes[:10]:
            print(f"{f}: {size/(1024*1024)} MBs uncompressed.")


def check_wheel_size(directory):
    for root, _, files in os.walk(directory):
        for f in files:
            if f.endswith(".whl"):
                wheel_path = os.path.join(root, f)
                wheel_size = os.path.getsize(wheel_path)
                wheel_size_mb = wheel_size / (1024 * 1024)
                if wheel_size_mb > MAX_SIZE_MB:
                    print(
                        f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
                        f"compare to the allowed size ({MAX_SIZE_MB} MB).")
                    print_top_10_largest_files(wheel_path)
                    return 1
                else:
                    print(f"Wheel {wheel_path} is within the allowed size "
                          f"({wheel_size_mb} MB).")
    return 0


if __name__ == "__main__":
    import sys
    sys.exit(check_wheel_size(sys.argv[1]))
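
Usage note (not part of the diff above): the checker takes one positional argument, the directory to scan, and exits non-zero when any wheel exceeds MAX_SIZE_MB. A minimal local invocation might look like the sketch below, where dist/ is an assumed build-output directory, not something specified by this diff.

```bash
# Hypothetical local check; "dist/" is an assumed wheel output directory.
python3 .buildkite/check-wheel-size.py dist/
echo "exit code: $?"   # 0 = every wheel is under 200 MB, 1 = at least one wheel is too large
```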

26  .buildkite/nightly-benchmarks/kickoff-pipeline.sh  Executable file

@@ -0,0 +1,26 @@
#!/usr/bin/env bash

set -euo pipefail

# Install system packages
apt update
apt install -y curl jq

# Install minijinja for templating
curl -sSfL https://github.com/mitsuhiko/minijinja/releases/latest/download/minijinja-cli-installer.sh | sh
source $HOME/.cargo/env

# If BUILDKITE_PULL_REQUEST != "false", then we check the PR labels using curl and jq
if [ "$BUILDKITE_PULL_REQUEST" != "false" ]; then
  PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name')

  if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then
    echo "This PR has the 'perf-benchmarks' label. Proceeding with the nightly benchmarks."
  else
    echo "This PR does not have the 'perf-benchmarks' label. Skipping the nightly benchmarks."
    exit 0
  fi
fi

# Upload sample.yaml
buildkite-agent pipeline upload .buildkite/nightly-benchmarks/sample.yaml
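
As a rough sketch (not part of the CI script), the same label gate can be previewed outside Buildkite by exporting the variable the script reads and querying the same GitHub endpoint; the PR number below is a placeholder.

```bash
# Sketch of the gating logic for a hypothetical PR number; requires curl and jq.
export BUILDKITE_PULL_REQUEST=12345
PR_LABELS=$(curl -s "https://api.github.com/repos/vllm-project/vllm/pulls/$BUILDKITE_PULL_REQUEST" | jq -r '.labels[].name')
if [[ $PR_LABELS == *"perf-benchmarks"* ]]; then
  echo "would kick off the nightly benchmarks"
else
  echo "would skip the nightly benchmarks"
fi
```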

39  .buildkite/nightly-benchmarks/sample.yaml  Normal file

@@ -0,0 +1,39 @@
steps:
  # NOTE(simon): You can create separate blocks for different jobs
  - label: "A100: NVIDIA SMI"
    agents:
      queue: A100
    plugins:
    - kubernetes:
        podSpec:
          containers:
          # - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT
          # TODO(simon): check latest main branch or use the PR image.
          - image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
            command:
            - bash -c 'nvidia-smi && nvidia-smi topo -m && pwd && ls'
            resources:
              limits:
                nvidia.com/gpu: 8
            volumeMounts:
            - name: devshm
              mountPath: /dev/shm
          nodeSelector:
            nvidia.com/gpu.product: NVIDIA-A100-SXM4-80GB
          volumes:
          - name: devshm
            emptyDir:
              medium: Memory
  # TODO(simon): bring H100 online
  # - label: "H100: NVIDIA SMI"
  #   agents:
  #     queue: H100
  #   plugins:
  #   - docker#v5.11.0:
  #       image: us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:45c35f0d58f4508bf43bd6af1d3d0d0ec0c915e6
  #       command:
  #       - bash -c 'nvidia-smi && nvidia-smi topo -m'
  #       propagate-environment: true
  #       ipc: host
  #       gpus: all

@@ -1,38 +1,73 @@
# This script build the ROCm docker image and run the API server inside the container.
# It serves a sanity check for compilation and basic model usage.
# This script runs test inside the corresponding ROCm docker container.
set -ex

# Print ROCm version
echo "--- ROCm info"
rocminfo

# Try building the docker image
docker build -t rocm -f Dockerfile.rocm .

# Setup cleanup
remove_docker_container() { docker rm -f rocm || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image
docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server &

# Wait for the server to start
wait_for_server_to_start() {
    timeout=300
    counter=0

    while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
        sleep 1
        counter=$((counter + 1))
        if [ $counter -ge $timeout ]; then
            echo "Timeout after $timeout seconds"
            break
        fi
    done
# cleanup older docker images
cleanup_docker() {
  # Get Docker's root directory
  docker_root=$(docker info -f '{{.DockerRootDir}}')
  if [ -z "$docker_root" ]; then
    echo "Failed to determine Docker root directory."
    exit 1
  fi
  echo "Docker root directory: $docker_root"
  # Check disk usage of the filesystem where Docker's root directory is located
  disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
  # Define the threshold
  threshold=70
  if [ "$disk_usage" -gt "$threshold" ]; then
    echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
    # Remove dangling images (those that are not tagged and not used by any container)
    docker image prune -f
    # Remove unused volumes
    docker volume prune -f
    echo "Docker images and volumes cleanup completed."
  else
    echo "Disk usage is below $threshold%. No cleanup needed."
  fi
}
wait_for_server_to_start

# Test a simple prompt
curl -X POST -H "Content-Type: application/json" \
    localhost:8000/generate \
    -d '{"prompt": "San Francisco is a"}'
# Call the cleanup docker function
cleanup_docker

echo "--- Resetting GPUs"

echo "reset" > /opt/amdgpu/etc/gpu_state

while true; do
        sleep 3
        if grep -q clean /opt/amdgpu/etc/gpu_state; then
                echo "GPUs state is \"clean\""
                break
        fi
done

echo "--- Building container"
sha=$(git rev-parse --short HEAD)
image_name=rocm_${sha}
container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)
docker build \
        -t ${image_name} \
        -f Dockerfile.rocm \
        --progress plain \
        .

remove_docker_container() {
   docker rm -f ${container_name} || docker image rm -f ${image_name} || true
}
trap remove_docker_container EXIT

echo "--- Running container"

docker run \
        --device /dev/kfd --device /dev/dri \
        --network host \
        --rm \
        -e HF_TOKEN \
        --name ${container_name} \
        ${image_name} \
        /bin/bash -c "${@}"

@@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)

# run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?

python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?

# run server-based benchmarks and upload the result to buildkite
@@ -50,11 +50,16 @@ echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
echo '```' >> benchmark_results.md
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines
echo '```' >> benchmark_results.md

# if the agent binary is not found, skip uploading the results, exit 0
if [ ! -f /usr/bin/buildkite-agent ]; then
    exit 0
fi

# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md

# exit with the exit code of the benchmarks
if [ $bench_latency_exit_code -ne 0 ]; then
@@ -69,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then
    exit $bench_serving_exit_code
fi

/workspace/buildkite-agent artifact upload openai-*.json
rm ShareGPT_V3_unfiltered_cleaned_split.json
buildkite-agent artifact upload "*.json"

24  .buildkite/run-cpu-test.sh  Normal file

@@ -0,0 +1,24 @@
# This script build the CPU docker image and run the offline inference inside the container.
# It serves a sanity check for compilation and basic model usage.
set -ex

# Try building the docker image
docker build -t cpu-test -f Dockerfile.cpu .

# Setup cleanup
remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image
docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test

# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"

# Run basic model test
docker exec cpu-test bash -c "cd tests;
  pip install pytest Pillow protobuf
  bash ../.buildkite/download-images.sh
  cd ../
  pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
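
If the cpu-test container started by this script is still running, individual checks can be re-executed against it without rebuilding the image; this is a local sketch, not part of the CI script itself.

```bash
# Re-run only the offline inference sanity check inside the already-started container.
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
# Tear the container down manually when done (the CI trap above does this automatically).
docker rm -f cpu-test
```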

51  .buildkite/run-neuron-test.sh  Normal file

@@ -0,0 +1,51 @@
# This script build the Neuron docker image and run the API server inside the container.
# It serves a sanity check for compilation and basic model usage.
set -e

# Try building the docker image
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com

# prune old image and containers to save disk space, and only once a day
# by using a timestamp file in tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
    last_build=$(cat /tmp/neuron-docker-build-timestamp)
    current_time=$(date +%s)
    if [ $((current_time - last_build)) -gt 86400 ]; then
        docker system prune -f
        echo $current_time > /tmp/neuron-docker-build-timestamp
    fi
else
    echo $(date +%s) > /tmp/neuron-docker-build-timestamp
fi

docker build -t neuron -f Dockerfile.neuron .

# Setup cleanup
remove_docker_container() { docker rm -f neuron || true; }
trap remove_docker_container EXIT
remove_docker_container

# Run the image
docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \
       --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 &

# Wait for the server to start
wait_for_server_to_start() {
    timeout=300
    counter=0

    while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
        sleep 1
        counter=$((counter + 1))
        if [ $counter -ge $timeout ]; then
            echo "Timeout after $timeout seconds"
            break
        fi
    done
}
wait_for_server_to_start

# Test a simple prompt
curl -X POST -H "Content-Type: application/json" \
    localhost:8000/generate \
    -d '{"prompt": "San Francisco is a"}'

@@ -5,92 +5,164 @@

steps:
- label: Regression Test
  mirror_hardwares: [amd]
  command: pytest -v -s test_regression.py
  working_dir: "/vllm-workspace/tests" # optional

- label: AsyncEngine Test
  #mirror_hardwares: [amd]
  command: pytest -v -s async_engine

- label: Basic Correctness Test
  command: pytest -v -s basic_correctness
  mirror_hardwares: [amd]
  commands:
  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
  - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py

- label: Core Test
  mirror_hardwares: [amd]
  command: pytest -v -s core

- label: Distributed Comm Ops Test
  command: pytest -v -s test_comm_ops.py
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.
  #mirror_hardwares: [amd]
  command: pytest -v -s distributed/test_comm_ops.py
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2

- label: Distributed Tests
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.
  mirror_hardwares: [amd]
  working_dir: "/vllm-workspace/tests"
  num_gpus: 2
  commands:
  - pytest -v -s test_pynccl.py
  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
  - pytest -v -s spec_decode/e2e/test_integration_dist.py
  - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py

- label: Distributed Tests (Multiple Groups)
  #mirror_hardwares: [amd]
  working_dir: "/vllm-workspace/tests"
  num_gpus: 4
  commands:
  - pytest -v -s distributed/test_pynccl.py

- label: Engine Test
  command: pytest -v -s engine tokenization test_sequence.py test_config.py
  mirror_hardwares: [amd]
  command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py

- label: Entrypoints Test
  command: pytest -v -s entrypoints
  mirror_hardwares: [amd]

  commands:
  - pytest -v -s entrypoints -m llm
  - pytest -v -s entrypoints -m openai

- label: Examples Test
  working_dir: "/vllm-workspace/examples"
  mirror_hardwares: [amd]
  commands:
    # install aws cli for llava_example.py
    - pip install awscli
    # install tensorizer for tensorize_vllm_model.py
    - pip install awscli tensorizer
    - python3 offline_inference.py
    - python3 offline_inference_with_prefix.py
    - python3 llm_engine_example.py
    - python3 llava_example.py
    - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors

- label: Inputs Test
  #mirror_hardwares: [amd]
  commands:
    - bash ../.buildkite/download-images.sh
    - pytest -v -s test_inputs.py
    - pytest -v -s multimodal

- label: Kernels Test %N
  #mirror_hardwares: [amd]
  command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
  parallelism: 4

- label: Models Test
  #mirror_hardwares: [amd]
  commands:
    - bash ../.buildkite/download-images.sh
    - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
    - pytest -v -s models -m \"not llava\"

- label: Llava Test
  mirror_hardwares: [amd]
  commands:
    - bash ../.buildkite/download-images.sh
    - pytest -v -s models/test_llava.py
    - pytest -v -s models -m llava

- label: Prefix Caching Test
  mirror_hardwares: [amd]
  commands:
    - pytest -v -s prefix_caching

- label: Samplers Test
  #mirror_hardwares: [amd]
  command: pytest -v -s samplers

- label: LogitsProcessor Test
  mirror_hardwares: [amd]
  command: pytest -v -s test_logits_processor.py

- label: Utils Test
  command: pytest -v -s test_utils.py

- label: Worker Test
  mirror_hardwares: [amd]
  command: pytest -v -s worker

- label: Speculative decoding tests
  command: pytest -v -s spec_decode
  #mirror_hardwares: [amd]
  commands:
    # See https://github.com/vllm-project/vllm/issues/5152
    - export VLLM_ATTENTION_BACKEND=XFORMERS
    - pytest -v -s spec_decode

- label: LoRA Test %N
  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
  #mirror_hardwares: [amd]
  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
  parallelism: 4

- label: LoRA Long Context (Distributed)
  #mirror_hardwares: [amd]
  num_gpus: 4
  # This test runs llama 13B, so it is required to run on 4 GPUs.
  commands:
    - pytest -v -s -x lora/test_long_context.py

- label: Tensorizer Test
  #mirror_hardwares: [amd]
  command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader

- label: Metrics Test
  mirror_hardwares: [amd]
  command: pytest -v -s metrics

- label: Quantization Test
  #mirror_hardwares: [amd]
  command: pytest -v -s quantization

- label: Benchmarks
  working_dir: "/vllm-workspace/.buildkite"
  mirror_hardwares: [amd]
  commands:
  - pip install aiohttp
  - bash run-benchmarks.sh

- label: Documentation Build
  working_dir: "/vllm-workspace/docs"
  working_dir: "/vllm-workspace/test_docs/docs"
  no_gpu: True
  commands:
  - pip install -r requirements-docs.txt
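
For reference, the steps above are plain pytest invocations run from /vllm-workspace/tests inside the CI image (see the default_working_dir in the templates below), so a single step can usually be approximated locally from a vLLM source checkout. As one sketch, the "Speculative decoding tests" step roughly corresponds to:

```bash
# Approximate local reproduction of one pipeline step; assumes a vLLM checkout with test dependencies installed.
cd tests
export VLLM_ATTENTION_BACKEND=XFORMERS   # see https://github.com/vllm-project/vllm/issues/5152
pytest -v -s spec_decode
```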

64  .buildkite/test-template-aws.j2  Normal file

@@ -0,0 +1,64 @@
{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
{% set default_working_dir = "/vllm-workspace/tests" %}

steps:
  - label: ":docker: build image"
    agents:
      queue: cpu_queue
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
      - "docker push {{ docker_image }}"
    env:
      DOCKER_BUILDKIT: "1"
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
        - exit_status: -10  # Agent was lost
          limit: 5
  - wait

  {% for step in steps %}
  - label: "{{ step.label }}"
    agents:
      {% if step.label == "Documentation Build" %}
      queue: small_cpu_queue
      {% elif step.no_gpu %}
      queue: cpu_queue
      {% elif step.num_gpus == 2 or step.num_gpus == 4 %}
      queue: gpu_4_queue
      {% else %}
      queue: gpu_1_queue
      {% endif %}
    soft_fail: true
    {% if step.parallelism %}
    parallelism: {{ step.parallelism }}
    {% endif %}
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
        - exit_status: -10  # Agent was lost
          limit: 5
    plugins:
      - docker#v5.2.0:
          image: {{ docker_image }}
          always-pull: true
          propagate-environment: true
          {% if not step.no_gpu %}
          gpus: all
          {% endif %}
          {% if step.label == "Benchmarks" %}
          mount-buildkite-agent: true
          {% endif %}
          command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe  }} && {{ step.command  or (step.commands | join(' && ')) | safe }}"]
          environment:
            - VLLM_USAGE_SOURCE=ci-test
            - HF_TOKEN
            {% if step.label == "Speculative decoding tests" %}
            - VLLM_ATTENTION_BACKEND=XFORMERS
            {% endif %}
          volumes:
            - /dev/shm:/dev/shm
  {% endfor %}

@@ -3,11 +3,6 @@
{% set default_working_dir = "/vllm-workspace/tests" %}

steps:
  - label: "AMD Test"
    agents:
      queue: amd
    command: bash .buildkite/run-amd-test.sh

  - label: ":docker: build image"
    commands:
      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
@@ -18,8 +13,38 @@ steps:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
        - exit_status: -10  # Agent was lost
          limit: 5
  - wait

  - group: "AMD Tests"
    depends_on: ~
    steps:
    {% for step in steps %}
    {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
      - label: "AMD: {{ step.label }}"
        agents:
          queue: amd
        command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe  }} ; {{ step.command  or (step.commands | join(" ; ")) | safe }}"
        env:
          DOCKER_BUILDKIT: "1"
        soft_fail: true
    {% endif %}
    {% endfor %}

  - label: "Neuron Test"
    depends_on: ~
    agents:
      queue: neuron
    command: bash .buildkite/run-neuron-test.sh
    soft_fail: false

  - label: "Intel Test"
    depends_on: ~
    agents:
      queue: intel
    command: bash .buildkite/run-cpu-test.sh

  {% for step in steps %}
  - label: "{{ step.label }}"
    agents:
@@ -32,9 +57,14 @@ steps:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
        - exit_status: -10  # Agent was lost
          limit: 5
    plugins:
      - kubernetes:
          podSpec:
            {% if step.num_gpus %}
            priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
            {% endif %}
            volumes:
              - name: dshm
                emptyDir:

26  .clang-format  Normal file

@@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80

# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left

# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false

# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash

IncludeCategories:
  - Regex:           '^<'
    Priority:        4
  - Regex:           '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
    Priority:        3
  - Regex:           '^"(qoda|\.\.)/'
    Priority:        2
  - Regex:           '.*'
    Priority:        1

1  .github/ISSUE_TEMPLATE/200-installation.yml  vendored

@@ -18,6 +18,7 @@ body:
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
      It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
    value: |
      ```text
      The output of `python collect_env.py`

1  .github/ISSUE_TEMPLATE/300-usage.yml  vendored

@@ -18,6 +18,7 @@ body:
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
      It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
    value: |
      ```text
      The output of `python collect_env.py`

5  .github/ISSUE_TEMPLATE/400-bug report.yml  vendored

@@ -18,6 +18,7 @@ body:
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
      It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
    value: |
      ```text
      The output of `python collect_env.py`
@@ -57,6 +58,10 @@ body:
      If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.

      Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.

      Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues.

      If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
    placeholder: |
      A clear and concise description of what the bug is.

@@ -39,6 +39,7 @@ body:
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
      It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
    value: |
      ```text
      The output of `python collect_env.py`

49  .github/ISSUE_TEMPLATE/750-RFC.yml  vendored  Normal file

@@ -0,0 +1,49 @@
name: 💬 Request for comments (RFC).
description: Ask for feedback on major architectural changes or design choices.
title: "[RFC]: "
labels: ["RFC"]

body:
- type: markdown
  attributes:
    value: >
      #### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
- type: textarea
  attributes:
    label: Motivation.
    description: >
      The motivation of the RFC.
  validations:
    required: true
- type: textarea
  attributes:
    label: Proposed Change.
    description: >
      The proposed change of the RFC.
  validations:
    required: true
- type: textarea
  attributes:
    label: Feedback Period.
    description: >
      The feedback period of the RFC. Usually at least one week.
  validations:
    required: false
- type: textarea
  attributes:
    label: CC List.
    description: >
      The list of people you want to CC.
  validations:
    required: false
- type: textarea
  attributes:
    label: Any Other Things.
    description: >
      Any other things you would like to mention.
  validations:
    required: false
- type: markdown
  attributes:
    value: >
      Thanks for contributing 🎉!

42  .github/workflows/clang-format.yml  vendored  Normal file

@@ -0,0 +1,42 @@
name: clang-format

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install clang-format==18.1.5
    - name: Running clang-format
      run: |
        EXCLUDES=(
            'csrc/moe/topk_softmax_kernels.cu'
            'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
            'csrc/punica/bgmv/bgmv_config.h'
            'csrc/punica/bgmv/bgmv_impl.cuh'
            'csrc/punica/bgmv/vec_dtypes.cuh'
            'csrc/punica/punica_ops.cu'
            'csrc/punica/type_convert.h'
        )
        find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
            | xargs clang-format --dry-run --Werror
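
The same check can be run locally before pushing; this sketch mirrors the workflow's find/xargs pipeline and assumes clang-format 18.1.5 is on PATH (add the EXCLUDES filter from the workflow above if needed).

```bash
# Local dry run of the CI formatting check, from the repository root.
pip install clang-format==18.1.5
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
    | xargs clang-format --dry-run --Werror
```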

51  .github/workflows/mypy.yaml  vendored  Normal file

@@ -0,0 +1,51 @@
 | 
			
		||||
name: mypy
 | 
			
		||||
 | 
			
		||||
on:
 | 
			
		||||
  # Trigger the workflow on push or pull request,
 | 
			
		||||
  # but only for the main branch
 | 
			
		||||
  push:
 | 
			
		||||
    branches:
 | 
			
		||||
      - main
 | 
			
		||||
  pull_request:
 | 
			
		||||
    branches:
 | 
			
		||||
      - main
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  ruff:
 | 
			
		||||
    runs-on: ubuntu-latest
 | 
			
		||||
    strategy:
 | 
			
		||||
      matrix:
 | 
			
		||||
        python-version: ["3.8", "3.9", "3.10", "3.11"]
 | 
			
		||||
    steps:
 | 
			
		||||
    - uses: actions/checkout@v2
 | 
			
		||||
    - name: Set up Python ${{ matrix.python-version }}
 | 
			
		||||
      uses: actions/setup-python@v2
 | 
			
		||||
      with:
 | 
			
		||||
        python-version: ${{ matrix.python-version }}
 | 
			
		||||
    - name: Install dependencies
 | 
			
		||||
      run: |
 | 
			
		||||
        python -m pip install --upgrade pip
 | 
			
		||||
        pip install mypy==1.9.0
 | 
			
		||||
        pip install types-setuptools
 | 
			
		||||
        pip install types-PyYAML
 | 
			
		||||
        pip install types-requests
 | 
			
		||||
        pip install types-setuptools
 | 
			
		||||
    - name: Mypy
 | 
			
		||||
      run: |
 | 
			
		||||
        mypy vllm/attention --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/core --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/distributed --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/entrypoints --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/executor --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/multimodal --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/usage --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/*.py --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/transformers_utils --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/engine  --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/worker --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/spec_decode --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/model_executor  --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/lora --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/logging --config-file pyproject.toml
 | 
			
		||||
        mypy vllm/model_executor --config-file pyproject.toml
 | 
			
		||||
 | 
			
		||||
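The type-check step can be run locally with the same pinned toolchain; a minimal sketch covering a representative subset of the per-package invocations the job above runs:

```bash
# Local mypy run mirroring the CI job above (same pinned version and stub packages).
pip install mypy==1.9.0 types-setuptools types-PyYAML types-requests
# The CI job invokes mypy once per vllm/* package; a few representative calls:
mypy vllm/attention --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
```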
							
								
								
									
.github/workflows/publish.yml  (vendored, 10 changed lines)

@@ -49,13 +49,19 @@ jobs:
      matrix:
          os: ['ubuntu-20.04']
          python-version: ['3.8', '3.9', '3.10', '3.11']
          pytorch-version: ['2.1.2']  # Must be the most recent version that meets requirements.txt.
          pytorch-version: ['2.3.0']  # Must be the most recent version that meets requirements-cuda.txt.
          cuda-version: ['11.8', '12.1']

    steps:
      - name: Checkout
        uses: actions/checkout@v3

      - name: Setup ccache
        uses: hendrikmuhs/ccache-action@v1.2
        with:
          create-symlink: true
          key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}

      - name: Set up Linux Env
        if: ${{ runner.os == 'Linux' }}
        run: |
@@ -76,6 +82,8 @@ jobs:

      - name: Build wheel
        shell: bash
        env:
          CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
        run: |
          bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
          wheel_name=$(ls dist/*whl | xargs -n 1 basename)
							
								
								
									
.github/workflows/ruff.yml  (vendored, 2 changed lines)

@@ -15,7 +15,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
        python-version: ["3.8", "3.9", "3.10", "3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
							
								
								
									
.github/workflows/scripts/build.sh  (vendored, 5 changed lines)

@@ -9,12 +9,13 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH

# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt
$python_executable -m pip install -r requirements-cuda.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1

# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
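The publish workflow above drives this script once per (Python version, CUDA version) matrix entry. A hedged sketch of invoking it by hand with illustrative version arguments, assuming the matching toolchain is installed as on the CI runner:

```bash
# Manual wheel build using the same entry point as the publish job above.
# "3.11" and "12.1" are example matrix values, not the only supported combination.
export CMAKE_BUILD_TYPE=Release   # as in the workflow: no debug symbols in the wheel
bash -x .github/workflows/scripts/build.sh 3.11 12.1
ls dist/*.whl                     # the built wheel lands under dist/
```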
							
								
								
									
.github/workflows/scripts/create_release.js  (vendored, 2 changed lines)

@@ -8,7 +8,7 @@ module.exports = async (github, context, core) => {
			generate_release_notes: true,
			name: process.env.RELEASE_TAG,
			owner: context.repo.owner,
			prerelease: false,
			prerelease: true,
			repo: context.repo.repo,
			tag_name: process.env.RELEASE_TAG,
		});
							
								
								
									
.github/workflows/yapf.yml  (vendored, 2 changed lines)

@@ -14,7 +14,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
        python-version: ["3.8", "3.9", "3.10", "3.11"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
							
								
								
									
.gitignore  (vendored, 3 changed lines)

@@ -70,6 +70,8 @@ instance/

# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst

# PyBuilder
.pybuilder/
@@ -181,6 +183,7 @@ _build/
# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
CMakeLists.txt

@@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.21)

project(vllm_extensions LANGUAGES CXX)

option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")

message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

@@ -16,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")

#
# Supported/expected torch versions for CUDA/ROCm.
@@ -28,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
# requirements.txt files and should be kept consistent.  The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")

@@ -64,17 +67,17 @@ endif()
find_package(Torch REQUIRED)

#
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
# `libtorch_python.so` for linking against an extension. Torch's cmake
# configuration does not include this library (presumably since the cmake
# config is used for standalone C++ binaries that link against torch).
# The `libtorch_python.so` library defines some of the glue code between
# torch/python via pybind and is required by VLLM extensions for this
# reason. So, add it by manually with `find_library` using torch's
# installed library path.
# Forward the non-CUDA device extensions to external CMake scripts.
#
find_library(torch_python_LIBRARY torch_python PATHS
  "${TORCH_INSTALL_PREFIX}/lib")
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
    NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
    if (VLLM_TARGET_DEVICE STREQUAL "cpu")
        include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
    else()
        message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
    endif()
    return()
endif()

#
# Set up GPU language and check the torch version and warn if it isn't
@@ -151,15 +154,47 @@ set(VLLM_EXT_SRC
  "csrc/layernorm_kernels.cu"
  "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
  "csrc/quantization/fp8/common.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/moe_align_block_size_kernels.cu"
  "csrc/pybind.cpp")
  "csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
  include(FetchContent)
  SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
  FetchContent_Declare(
        cutlass
        GIT_REPOSITORY https://github.com/nvidia/cutlass.git
        # CUTLASS 3.5.0
        GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
  )
  FetchContent_MakeAvailable(cutlass)

  list(APPEND VLLM_EXT_SRC
    "csrc/quantization/aqlm/gemm_kernels.cu"
    "csrc/quantization/awq/gemm_kernels.cu"
    "csrc/quantization/marlin/marlin_cuda_kernel.cu"
    "csrc/custom_all_reduce.cu")
    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
    "csrc/custom_all_reduce.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
    "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")

  #
  # The CUTLASS kernels for Hopper require sm90a to be enabled.
  # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
  # That adds an extra 17MB to compiled binary, so instead we selectively enable it.
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
    set_source_files_properties(
          "csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
          PROPERTIES
          COMPILE_FLAGS
          "-gencode arch=compute_90a,code=sm_90a")
  endif()

endif()

define_gpu_extension_target(
@@ -169,6 +204,8 @@ define_gpu_extension_target(
  SOURCES ${VLLM_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
  USE_SABI 3
  WITH_SOABI)

#
@@ -176,7 +213,7 @@ define_gpu_extension_target(
#

set(VLLM_MOE_EXT_SRC
  "csrc/moe/moe_ops.cpp"
  "csrc/moe/torch_bindings.cpp"
  "csrc/moe/topk_softmax_kernels.cu")

define_gpu_extension_target(
@@ -186,6 +223,7 @@ define_gpu_extension_target(
  SOURCES ${VLLM_MOE_EXT_SRC}
  COMPILE_FLAGS ${VLLM_GPU_FLAGS}
  ARCHITECTURES ${VLLM_GPU_ARCHES}
  USE_SABI 3
  WITH_SOABI)

#
@@ -194,24 +232,13 @@ define_gpu_extension_target(

set(VLLM_PUNICA_EXT_SRC
  "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
  "csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
  "csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
  "csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
  "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
  "csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
  "csrc/punica/punica_ops.cc")
  "csrc/punica/punica_ops.cu"
  "csrc/punica/torch_bindings.cpp")

#
# Copy GPU compilation flags+update for punica
@@ -235,6 +262,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
    endif()
  endforeach()
  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
  set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()

if (VLLM_PUNICA_GPU_ARCHES)
@@ -245,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
    SOURCES ${VLLM_PUNICA_EXT_SRC}
    COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
    ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
    USE_SABI 3
    WITH_SOABI)
else()
  message(WARNING "Unable to create _punica_C target because none of the "
@@ -269,9 +300,7 @@ add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
  message(STATUS "Enabling C extension.")
  add_dependencies(default _C)
endif()

if(VLLM_GPU_LANG STREQUAL "CUDA")
  message(STATUS "Enabling moe extension.")
  add_dependencies(default _moe_C)
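The new `VLLM_TARGET_DEVICE` option is the switch the non-CUDA build paths key off. A minimal sketch of selecting a backend at build time; the CPU invocation mirrors Dockerfile.cpu further below, and the default stays `cuda`:

```bash
# Default build: VLLM_TARGET_DEVICE defaults to "cuda", so the GPU extension targets are built.
pip install -e .

# CPU backend: CMake takes the new non-CUDA branch above and includes
# cmake/cpu_extension.cmake instead of the GPU targets (same command Dockerfile.cpu uses).
VLLM_TARGET_DEVICE=cpu python3 setup.py install
```

Dockerfile.neuron below sets the same variable in the environment (`ENV VLLM_TARGET_DEVICE neuron`) before its editable install.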
CONTRIBUTING.md

@@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat
### Build from source

```bash
pip install -r requirements.txt
pip install -e .  # This may take several minutes.
```

@@ -30,6 +29,8 @@ pip install -e .  # This may take several minutes.
```bash
pip install -r requirements-dev.txt

# linting and formatting
bash format.sh
# Static type checking
mypy
# Unit tests
							
								
								
									
Dockerfile  (118 changed lines)

@@ -1,8 +1,13 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server.

# Please update any changes made here to
# docs/source/dev/dockerfile/dockerfile.rst and
# docs/source/assets/dev/dockerfile-stages-dependency.png

#################### BASE BUILD IMAGE ####################
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
# prepare basic build environment
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev

RUN apt-get update -y \
    && apt-get install -y python3-pip git
@@ -11,23 +16,31 @@ RUN apt-get update -y \
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/
RUN ldconfig /usr/local/cuda-12.4/compat/

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements.txt requirements.txt
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements.txt
    pip install -r requirements-cuda.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-dev.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
#################### BASE BUILD IMAGE ####################


#################### EXTENSION BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################
FROM dev AS build

# install build dependencies
@@ -38,18 +51,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# install compiler cache to speed up compilation leveraging local or remote caching
RUN apt-get update -y && apt-get install -y ccache

# copy input files
# files and directories related to build wheels
COPY csrc csrc
COPY setup.py setup.py
COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt
COPY requirements.txt requirements.txt
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py
COPY vllm vllm

# cuda arch list used by torch
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
@@ -61,77 +72,64 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1

ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
    python3 setup.py build_ext --inplace
    --mount=type=cache,target=/root/.cache/pip \
    python3 setup.py bdist_wheel --dist-dir=dist

# check the size of the wheel, we cannot upload wheels larger than 100MB
COPY .buildkite/check-wheel-size.py check-wheel-size.py
RUN python3 check-wheel-size.py dist

#################### EXTENSION Build IMAGE ####################

#################### FLASH_ATTENTION Build IMAGE ####################
FROM dev as flash-attn-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# flash attention version
ARG flash_attn_version=v2.5.6
ENV FLASH_ATTN_VERSION=${flash_attn_version}
#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
WORKDIR /vllm-workspace

WORKDIR /usr/src/flash-attention-v2
RUN apt-get update -y \
    && apt-get install -y python3-pip git vim

# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
    --no-build-isolation --no-deps --no-cache-dir
# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.4/compat/

# install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
    --mount=type=cache,target=/root/.cache/pip \
    pip install dist/*.whl --verbose
#################### vLLM installation IMAGE ####################

#################### FLASH_ATTENTION Build IMAGE ####################

#################### TEST IMAGE ####################
# image to run unit testing suite
FROM dev AS test
# note that this uses vllm installed by `pip`
FROM vllm-base AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
WORKDIR /vllm-workspace
# ADD is used to preserve directory structure
ADD . /vllm-workspace/
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
# Install flash attention (from pre-built wheel)
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
# ignore build dependencies installation because we are using pre-complied extensions
RUN rm pyproject.toml
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
#################### TEST IMAGE ####################


#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

# libnccl required for ray
RUN apt-get update -y \
    && apt-get install -y python3-pip

WORKDIR /workspace
COPY requirements.txt requirements.txt
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements.txt
    pip install -r requirements-dev.txt

# Install flash attention (from pre-built wheel)
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir

#################### RUNTIME BASE IMAGE ####################
# doc requires source code
# we hide them inside `test_docs/` , so that this source code
# will not be imported by other tests
RUN mkdir test_docs
RUN mv docs test_docs/
RUN mv vllm test_docs/

#################### TEST IMAGE ####################

#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install accelerate hf_transfer modelscope

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENV VLLM_USAGE_SOURCE production-docker-image

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
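A hedged sketch of how these stages are typically consumed; the stage names (`dev`, `build`, `vllm-base`, `test`, `vllm-openai`) and build args (`max_jobs`, `torch_cuda_arch_list`) are the ones defined above, while the image tags, model name, and port mapping are purely illustrative:

```bash
# Build the final OpenAI-compatible server image.
docker build --target vllm-openai -t vllm-openai:dev .

# Build only the wheel-building stage, capping parallel compile jobs.
docker build --target build --build-arg max_jobs=8 -t vllm-build:dev .

# Run the server image; the ENTRYPOINT is vllm.entrypoints.openai.api_server,
# so container arguments are passed straight to it (the model here is an example).
docker run --gpus all -p 8000:8000 vllm-openai:dev --model facebook/opt-125m
```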
							
								
								
									
Dockerfile.cpu  (new file, 26 changed lines)

@@ -0,0 +1,26 @@
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.

FROM ubuntu:22.04 AS cpu-test-1

RUN apt-get update  -y \
    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12

RUN pip install --upgrade pip \
    && pip install wheel packaging ninja "setuptools>=49.4.0" numpy

FROM cpu-test-1 AS build

COPY ./ /workspace/vllm

WORKDIR /workspace/vllm

RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu

RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install

WORKDIR /workspace/

RUN ln -s /workspace/vllm/tests  && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks

CMD ["/bin/bash"]
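An illustrative build and run of the CPU-only image above; the tag and run flags are placeholders:

```bash
# Build and enter the x86 CPU image defined above (no GPU runtime required).
docker build -f Dockerfile.cpu -t vllm-cpu-env .
docker run -it --rm --network=host vllm-cpu-env
```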
							
								
								
									
Dockerfile.neuron  (new file, 36 changed lines)

@@ -0,0 +1,36 @@
# default base image
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"

FROM $BASE_IMAGE

RUN echo "Base image is $BASE_IMAGE"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U

COPY ./vllm /app/vllm/vllm
COPY ./setup.py /app/vllm/setup.py
COPY ./requirements-common.txt /app/vllm/requirements-common.txt
COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt

RUN cd /app/vllm \
    && python3 -m pip install -U -r requirements-neuron.txt

ENV VLLM_TARGET_DEVICE neuron
RUN cd /app/vllm \
    && pip install -e . \
    && cd ..

CMD ["/bin/bash"]
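Likewise for the Neuron image; the base image is the AWS Deep Learning Container pinned above, the tag is illustrative, and device pass-through flags for Inferentia/Trainium are omitted here:

```bash
# Build the Neuron image defined above; pulling the base image assumes access to the AWS ECR registry.
docker build -f Dockerfile.neuron -t vllm-neuron .
```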
Dockerfile.rocm

@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="3d2b6f5"
ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH"

# whether to build flash-attention
@@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build triton on rocm
ARG BUILD_TRITON="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -43,7 +46,7 @@ RUN apt-get update && apt-get install -y \
### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
ARG APP_MOUNT=/vllm-workspace
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

@@ -75,21 +78,38 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

COPY ./ /app/vllm
# build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
    mkdir -p libs \
    && cd libs \
    && pip uninstall -y triton \
    && git clone https://github.com/ROCm/triton.git \
    && cd triton/python \
    && pip3 install . \
    && cd ../..; \
    fi

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install xformers==0.0.23 --no-deps
WORKDIR /vllm-workspace
COPY . .

RUN cd /app \
    && cd vllm \
    && pip install -U -r requirements-rocm.txt \
    && if [ "$BUILD_FA" = "1" ]; then \
       bash patch_xformers.rocm.sh; fi \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
#RUN python3 -m pip install pynvml # to be removed eventually
RUN python3 -m pip install --upgrade pip numba

# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1

ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so

RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -U -r requirements-rocm.txt \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
    && python3 setup.py install \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \
    && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

CMD ["/bin/bash"]
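An illustrative ROCm image build using the build args introduced above (`BUILD_FA`, `BUILD_TRITON`, `FA_BRANCH`); setting a flag to "0" skips that component:

```bash
# Build the ROCm image; skip the flash-attention build, keep the Triton build.
docker build -f Dockerfile.rocm \
    --build-arg BUILD_FA="0" \
    --build-arg BUILD_TRITON="1" \
    -t vllm-rocm .
```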
MANIFEST.in

@@ -1,5 +1,9 @@
include LICENSE
include requirements.txt
include requirements-common.txt
include requirements-cuda.txt
include requirements-rocm.txt
include requirements-neuron.txt
include requirements-cpu.txt
include CMakeLists.txt

recursive-include cmake *
							
								
								
									
README.md  (86 changed lines)

@@ -16,16 +16,24 @@ Easy, fast, and cheap LLM serving for everyone

---

**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)**
**Ray Summit CPF is Open (June 4th to June 20th)!**

We are thrilled to announce our third vLLM Meetup!
There will be a track for vLLM at the Ray Summit (09/30-10/02, SF) this year!
If you have cool projects related to vLLM or LLM inference, we would love to see your proposals.
This will be a great chance for everyone in the community to get together and learn.
Please submit your proposal [here](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/eventsite)

**The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**

We are thrilled to announce our fourth vLLM Meetup!
The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us!
We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://lu.ma/agivllm) and join us!

---

*Latest News* 🔥
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
@@ -61,39 +69,14 @@ vLLM is flexible and easy to use with:
- (Experimental) Prefix caching support
- (Experimental) Multi-lora support

vLLM seamlessly supports many Hugging Face models, including the following architectures:
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral)
- Multi-modal LLMs (e.g., LLaVA)

- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).

## Getting Started

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
@@ -101,9 +84,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
pip install vllm
```

## Getting Started

Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
@@ -113,6 +94,33 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.

## Sponsors

vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!

<!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->

- a16z
- AMD
- Anyscale
- AWS
- Crusoe Cloud
- Databricks
- DeepInfra
- Dropbox
- Lambda Lab
- NVIDIA
- Replicate
- Roblox
- RunPod
- Sequoia Capital
- Trainy
- UC Berkeley
- UC San Diego

We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.

## Citation

If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
@ -27,8 +27,8 @@ class RequestFuncInput:
 | 
			
		||||
class RequestFuncOutput:
 | 
			
		||||
    generated_text: str = ""
 | 
			
		||||
    success: bool = False
 | 
			
		||||
    latency: float = 0
 | 
			
		||||
    ttft: float = 0  # Time to first token
 | 
			
		||||
    latency: float = 0.0
 | 
			
		||||
    ttft: float = 0.0  # Time to first token
 | 
			
		||||
    itl: List[float] = field(
 | 
			
		||||
        default_factory=list)  # List of inter-token latencies
 | 
			
		||||
    prompt_len: int = 0
 | 
			
		||||
@ -58,23 +58,24 @@ async def async_request_tgi(
 | 
			
		||||
        output = RequestFuncOutput()
 | 
			
		||||
        output.prompt_len = request_func_input.prompt_len
 | 
			
		||||
 | 
			
		||||
        ttft = 0
 | 
			
		||||
        ttft = 0.0
 | 
			
		||||
        st = time.perf_counter()
 | 
			
		||||
        most_recent_timestamp = st
 | 
			
		||||
        try:
 | 
			
		||||
            async with session.post(url=api_url, json=payload) as response:
 | 
			
		||||
                if response.status == 200:
 | 
			
		||||
                    async for chunk in response.content:
 | 
			
		||||
                        chunk = chunk.strip()
 | 
			
		||||
                        if not chunk:
 | 
			
		||||
                    async for chunk_bytes in response.content:
 | 
			
		||||
                        chunk_bytes = chunk_bytes.strip()
 | 
			
		||||
                        if not chunk_bytes:
 | 
			
		||||
                            continue
 | 
			
		||||
 | 
			
		||||
                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
 | 
			
		||||
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
 | 
			
		||||
                                              "data:")
 | 
			
		||||
 | 
			
		||||
                        data = json.loads(chunk)
 | 
			
		||||
                        timestamp = time.perf_counter()
 | 
			
		||||
                        # First token
 | 
			
		||||
                        if ttft == 0:
 | 
			
		||||
                        if ttft == 0.0:
 | 
			
		||||
                            ttft = time.perf_counter() - st
 | 
			
		||||
                            output.ttft = ttft
 | 
			
		||||
 | 
			
		||||
@ -88,6 +89,9 @@ async def async_request_tgi(
 | 
			
		||||
                    output.latency = most_recent_timestamp - st
 | 
			
		||||
                    output.success = True
 | 
			
		||||
                    output.generated_text = data["generated_text"]
 | 
			
		||||
                else:
 | 
			
		||||
                    output.error = response.reason or ""
 | 
			
		||||
                    output.success = False
 | 
			
		||||
        except Exception:
 | 
			
		||||
            output.success = False
 | 
			
		||||
            exc_info = sys.exc_info()
 | 
			
		||||
@ -119,23 +123,25 @@ async def async_request_trt_llm(
 | 
			
		||||
        output = RequestFuncOutput()
 | 
			
		||||
        output.prompt_len = request_func_input.prompt_len
 | 
			
		||||
 | 
			
		||||
        ttft = 0
 | 
			
		||||
        ttft = 0.0
 | 
			
		||||
        st = time.perf_counter()
 | 
			
		||||
        most_recent_timestamp = st
 | 
			
		||||
        try:
 | 
			
		||||
            async with session.post(url=api_url, json=payload) as response:
 | 
			
		||||
                if response.status == 200:
 | 
			
		||||
                    async for chunk in response.content:
 | 
			
		||||
                        chunk = chunk.strip()
 | 
			
		||||
                        if not chunk:
 | 
			
		||||
                    async for chunk_bytes in response.content:
 | 
			
		||||
                        chunk_bytes = chunk_bytes.strip()
 | 
			
		||||
                        if not chunk_bytes:
 | 
			
		||||
                            continue
 | 
			
		||||
 | 
			
		||||
                        chunk = remove_prefix(chunk.decode("utf-8"), "data:")
 | 
			
		||||
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
 | 
			
		||||
                                              "data:")
 | 
			
		||||
 | 
			
		||||
                        data = json.loads(chunk)
 | 
			
		||||
                        output.generated_text += data["text_output"]
 | 
			
		||||
                        timestamp = time.perf_counter()
 | 
			
		||||
                        # First token
 | 
			
		||||
                        if ttft == 0:
 | 
			
		||||
                        if ttft == 0.0:
 | 
			
		||||
                            ttft = time.perf_counter() - st
 | 
			
		||||
                            output.ttft = ttft
 | 
			
		||||
 | 
			
		||||
@ -147,11 +153,10 @@ async def async_request_trt_llm(
 | 
			
		||||
                        most_recent_timestamp = timestamp
 | 
			
		||||
 | 
			
		||||
                    output.latency = most_recent_timestamp - st
 | 
			
		||||
                    output.generated_text = json.loads(data)["text_output"]
 | 
			
		||||
                    output.success = True
 | 
			
		||||
 | 
			
		||||
                else:
 | 
			
		||||
                    output.error = response.reason
 | 
			
		||||
                    output.error = response.reason or ""
 | 
			
		||||
                    output.success = False
 | 
			
		||||
        except Exception:
 | 
			
		||||
            output.success = False
 | 
			
		||||
@ -195,7 +200,7 @@ async def async_request_deepspeed_mii(
 | 
			
		||||
                    output.generated_text = parsed_resp["text"][0]
 | 
			
		||||
                    output.success = True
 | 
			
		||||
                else:
 | 
			
		||||
                    output.error = response.reason
 | 
			
		||||
                    output.error = response.reason or ""
 | 
			
		||||
                    output.success = False
 | 
			
		||||
        except Exception:
 | 
			
		||||
            output.success = False
 | 
			
		||||
@ -234,19 +239,20 @@ async def async_request_openai_completions(
 | 
			
		||||
        output.prompt_len = request_func_input.prompt_len
 | 
			
		||||
 | 
			
		||||
        generated_text = ""
 | 
			
		||||
        ttft = 0
 | 
			
		||||
        ttft = 0.0
 | 
			
		||||
        st = time.perf_counter()
 | 
			
		||||
        most_recent_timestamp = st
 | 
			
		||||
        try:
 | 
			
		||||
            async with session.post(url=api_url, json=payload,
 | 
			
		||||
                                    headers=headers) as response:
 | 
			
		||||
                if response.status == 200:
 | 
			
		||||
                    async for chunk in response.content:
 | 
			
		||||
                        chunk = chunk.strip()
 | 
			
		||||
                        if not chunk:
 | 
			
		||||
                    async for chunk_bytes in response.content:
 | 
			
		||||
                        chunk_bytes = chunk_bytes.strip()
 | 
			
		||||
                        if not chunk_bytes:
 | 
			
		||||
                            continue
 | 
			
		||||
 | 
			
		||||
                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
 | 
			
		||||
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
 | 
			
		||||
                                              "data: ")
 | 
			
		||||
                        if chunk == "[DONE]":
 | 
			
		||||
                            latency = time.perf_counter() - st
 | 
			
		||||
                        else:
 | 
			
		||||
@ -255,7 +261,7 @@ async def async_request_openai_completions(
 | 
			
		||||
                            if data["choices"][0]["text"]:
 | 
			
		||||
                                timestamp = time.perf_counter()
 | 
			
		||||
                                # First token
 | 
			
		||||
                                if ttft == 0:
 | 
			
		||||
                                if ttft == 0.0:
 | 
			
		||||
                                    ttft = time.perf_counter() - st
 | 
			
		||||
                                    output.ttft = ttft
 | 
			
		||||
 | 
			
		||||
@ -273,6 +279,9 @@ async def async_request_openai_completions(
 | 
			
		||||
                    output.generated_text = generated_text
 | 
			
		||||
                    output.success = True
 | 
			
		||||
                    output.latency = latency
 | 
			
		||||
                else:
 | 
			
		||||
                    output.error = response.reason or ""
 | 
			
		||||
                    output.success = False
 | 
			
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
@@ -315,28 +324,30 @@ async def async_request_openai_chat_completions(
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
                        chunk = chunk.strip()
                        if not chunk:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if "content" in data["choices"][0]["delta"]:
                            delta = data["choices"][0]["delta"]
                            if delta.get("content", None):
                                # First token
                                if ttft == 0:
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

@@ -345,8 +356,7 @@ async def async_request_openai_chat_completions(
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += data["choices"][0]["delta"][
                                    "content"]
                                generated_text += delta["content"]

                            most_recent_timestamp = timestamp

@@ -354,7 +364,7 @@ async def async_request_openai_chat_completions(
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False

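The hunk above (apparently the OpenAI chat-completions handler in benchmarks/backend_request_func.py) switches to per-chunk timing: strip the SSE "data: " prefix from each streamed chunk, record time-to-first-token (TTFT) on the first content delta, and append inter-token latencies (ITL) for every later delta. A minimal, self-contained sketch of that pattern follows; it is not part of the diff, the chunk list stands in for response.content, and the name parse_stream is illustrative only.

import json
import time

def parse_stream(chunks):
    ttft = 0.0
    itl = []
    generated_text = ""
    st = time.perf_counter()
    most_recent_timestamp = st
    for chunk_bytes in chunks:
        chunk_bytes = chunk_bytes.strip()
        if not chunk_bytes:
            continue
        # Drop the SSE framing prefix before parsing the JSON payload.
        chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
        if chunk == "[DONE]":
            break
        timestamp = time.perf_counter()
        delta = json.loads(chunk)["choices"][0]["delta"]
        if delta.get("content"):
            if ttft == 0.0:
                ttft = timestamp - st          # first token
            else:
                itl.append(timestamp - most_recent_timestamp)
            generated_text += delta["content"]
        most_recent_timestamp = timestamp
    return generated_text, ttft, itl

fake_chunks = [
    b'data: {"choices": [{"delta": {"content": "Hello"}}]}',
    b'data: {"choices": [{"delta": {"content": " world"}}]}',
    b'data: [DONE]',
]
print(parse_stream(fake_chunks))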
@@ -1,14 +1,17 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import json
import time
from pathlib import Path
from typing import Optional
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


def main(args: argparse.Namespace):
@@ -17,6 +20,8 @@ def main(args: argparse.Namespace):
    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(model=args.model,
              speculative_model=args.speculative_model,
              num_speculative_tokens=args.num_speculative_tokens,
              tokenizer=args.tokenizer,
              quantization=args.quantization,
              tensor_parallel_size=args.tensor_parallel_size,
@@ -24,11 +29,15 @@ def main(args: argparse.Namespace):
              dtype=args.dtype,
              enforce_eager=args.enforce_eager,
              kv_cache_dtype=args.kv_cache_dtype,
              quantization_param_path=args.quantization_param_path,
              device=args.device,
              ray_workers_use_nsight=args.ray_workers_use_nsight,
              use_v2_block_manager=args.use_v2_block_manager,
              enable_chunked_prefill=args.enable_chunked_prefill,
              download_dir=args.download_dir,
              block_size=args.block_size)
              block_size=args.block_size,
              gpu_memory_utilization=args.gpu_memory_utilization,
              distributed_executor_backend=args.distributed_executor_backend)

    sampling_params = SamplingParams(
        n=args.n,
@@ -42,7 +51,9 @@ def main(args: argparse.Namespace):
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
    dummy_inputs: List[PromptStrictInputs] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
@@ -53,13 +64,13 @@ def main(args: argparse.Namespace):
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                llm.generate(dummy_inputs,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
        else:
            start_time = time.perf_counter()
            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
            llm.generate(dummy_inputs,
                         sampling_params=sampling_params,
                         use_tqdm=False)
            end_time = time.perf_counter()
@@ -67,7 +78,8 @@ def main(args: argparse.Namespace):
            return latency

    print("Warming up...")
    run_to_completion(profile_dir=None)
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
@@ -83,7 +95,22 @@ def main(args: argparse.Namespace):
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))
    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == '__main__':
@@ -91,10 +118,12 @@ if __name__ == '__main__':
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
@@ -105,9 +134,13 @@ if __name__ == '__main__':
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=10,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        default=30,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
@@ -125,12 +158,23 @@ if __name__ == '__main__':
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8_e5m2'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        '--profile',
        action='store_true',
@@ -145,18 +189,18 @@ if __name__ == '__main__':
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        type=bool,
        default=False,
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument('--use-v2-block-manager', action='store_true')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
@@ -167,5 +211,23 @@ if __name__ == '__main__':
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1.'
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    args = parser.parse_args()
    main(args)

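The diff above (apparently benchmarks/benchmark_latency.py) adds proper warmup iterations, percentile reporting, and an optional --output-json dump. A small sketch of the new summary step follows; the latency values are invented for illustration and this is not part of the diff.

import json
import numpy as np

latencies = np.array([0.92, 0.88, 1.05, 0.97, 0.91])  # seconds, illustrative only
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies, percentages)

print(f"Avg latency: {np.mean(latencies)} seconds")
for percentage, percentile in zip(percentages, percentiles):
    print(f"{percentage}% percentile latency: {percentile} seconds")

# Mirror of the optional JSON output the benchmark now supports.
results = {
    "avg_latency": float(np.mean(latencies)),
    "latencies": latencies.tolist(),
    "percentiles": dict(zip(percentages, percentiles.tolist())),
}
print(json.dumps(results, indent=4))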
@@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):


def main(args):
    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
    llm = LLM(model=args.model,
              tokenizer_mode='auto',
              trust_remote_code=True,
              enforce_eager=True,
              use_v2_block_manager=args.use_v2_block_manager,
              tensor_parallel_size=args.tensor_parallel_size,
              enable_prefix_caching=args.enable_prefix_caching)

    num_prompts = 100
    prompts = [PROMPT] * num_prompts
    sampling_params = SamplingParams(temperature=0, max_tokens=100)
    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

    print("------warm up------")
    test_prefix(
        llm=llm,
        prompts=prompts[:1],
        prompts=prompts,
        sampling_params=sampling_params,
    )

@@ -45,8 +47,16 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance with or without automatic '
        'prefix caching.')
    parser.add_argument('--model',
                        type=str,
                        default='baichuan-inc/Baichuan2-13B-Chat')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceMangerV2')
    args = parser.parse_args()
    main(args)

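The diff above (apparently benchmarks/benchmark_prefix_caching.py) makes the model, output length, and tensor-parallel size configurable and warms up on the full prompt list so the shared prefix is cached before measurement. A hedged usage sketch of the underlying idea follows; it assumes a working vLLM installation with a GPU, the model name is only an example, and the snippet is not part of the diff.

import time
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
prompt = "A very long shared context ... " * 100 + "Question: what happens next?"
sampling_params = SamplingParams(temperature=0, max_tokens=10)

llm.generate([prompt], sampling_params)       # warm up and populate the prefix cache
start = time.perf_counter()
llm.generate([prompt] * 10, sampling_params)  # repeated prompts reuse the cached prefix blocks
print(f"10 cached-prefix generations took {time.perf_counter() - start:.2f}s")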
@@ -17,6 +17,10 @@ On the client side, run:
        --dataset-path <path to dataset> \
        --request-rate <request_rate> \ # By default <request_rate> is inf
        --num-prompts <num_prompts> # By default <num_prompts> is 1000

    when using tgi backend, add
        --endpoint /generate_stream
    to the end of the command above.
"""
import argparse
import asyncio
@@ -27,7 +31,7 @@ import time
import warnings
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
from typing import AsyncGenerator, List, Optional, Tuple

import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@@ -52,13 +56,20 @@ class BenchmarkMetrics:
    mean_tpot_ms: float
    median_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    p99_itl_ms: float


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
@@ -68,38 +79,32 @@ def sample_sharegpt_requests(
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # some of these will be filtered out, so sample more than we need
    sampled_indices = random.sample(range(len(dataset)),
                                    int(num_requests * 1.2))
    dataset = [dataset[i] for i in sampled_indices]
    # Shuffle the dataset.
    random.shuffle(dataset)

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests
    return filtered_dataset


def sample_sonnet_requests(
@@ -110,7 +115,9 @@ def sample_sonnet_requests(
    prefix_len: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]:
    assert input_len > prefix_len, "input_len must be greater than prefix_len."
    assert (
        input_len > prefix_len
    ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."

    # Load the dataset.
    with open(dataset_path) as f:
@@ -131,8 +138,9 @@ def sample_sonnet_requests(
        base_message, add_generation_prompt=True, tokenize=False)
    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)

    assert (input_len > base_prompt_offset
            ), f"Please set 'args.input-len' higher than {base_prompt_offset}."
    assert (
        input_len > base_prompt_offset
    ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
    num_input_lines = round(
        (input_len - base_prompt_offset) / average_poem_len)

@@ -140,7 +148,7 @@ def sample_sonnet_requests(
    # prompt are fixed poem lines.
    assert (
        prefix_len > base_prompt_offset
    ), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."
    ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."

    num_prefix_lines = round(
        (prefix_len - base_prompt_offset) / average_poem_len)
@@ -195,21 +203,34 @@ def calculate_metrics(
    actual_output_lens = []
    total_input = 0
    completed = 0
    itls = []
    tpots = []
    ttfts = []
    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = len(tokenizer(outputs[i].generated_text).input_ids)
            # We use the tokenizer to count the number of output tokens for all
            # serving backends instead of looking at len(outputs[i].itl) since
            # multiple output tokens may be bundled together
            # Note: this may inflate the output token count slightly
            output_len = len(
                tokenizer(outputs[i].generated_text,
                          add_special_tokens=False).input_ids)
            actual_output_lens.append(output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
                tpots.append(
                    (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)
            completed += 1
        else:
            actual_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2)
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
@@ -221,9 +242,12 @@ def calculate_metrics(
        1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots) * 1000,
        median_tpot_ms=np.median(tpots) * 1000,
        p99_tpot_ms=np.percentile(tpots, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
    )

    return metrics, actual_output_lens
@@ -245,6 +269,24 @@ async def benchmark(
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        best_of=best_of,
        use_beam_search=use_beam_search,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting main benchmark run...")
    print(f"Traffic request rate: {request_rate}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@@ -305,6 +347,10 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
                                    metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    result = {
@@ -321,6 +367,9 @@ async def benchmark(
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
@@ -358,6 +407,7 @@ def main(args: argparse.Namespace):
            dataset_path=args.dataset,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )

    elif args.dataset_name == "sharegpt":
@@ -365,6 +415,7 @@ def main(args: argparse.Namespace):
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )

    elif args.dataset_name == "sonnet":
@@ -373,9 +424,9 @@ def main(args: argparse.Namespace):
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.input_len,
                output_len=args.output_len,
                prefix_len=args.prefix_len,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt, prompt_len, output_len)
@@ -388,9 +439,9 @@ def main(args: argparse.Namespace):
            input_requests = sample_sonnet_requests(
                dataset_path=args.dataset_path,
                num_requests=args.num_prompts,
                input_len=args.input_len,
                output_len=args.output_len,
                prefix_len=args.prefix_len,
                input_len=args.sonnet_input_len,
                output_len=args.sonnet_output_len,
                prefix_len=args.sonnet_prefix_len,
                tokenizer=tokenizer,
            )
            input_requests = [(prompt_formatted, prompt_len, output_len)
@@ -521,6 +572,12 @@ if __name__ == "__main__":
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.")
    parser.add_argument(
        "--sonnet-input-len",
        type=int,

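The diff above (apparently benchmarks/benchmark_serving.py) adds inter-token latency (ITL) alongside TTFT and TPOT. A tiny worked example of how those three metrics relate follows; the numbers are invented for illustration and this is not part of the diff.

import numpy as np

ttft = 0.12          # seconds until the first token arrived
latency = 1.32       # total end-to-end request latency in seconds
output_len = 61      # tokens counted by re-tokenizing the generated text

# TPOT spreads the decode time after the first token over the remaining tokens.
tpot = (latency - ttft) / (output_len - 1)
# ITL is the raw gap between consecutive streamed chunks (a subset shown here).
itl = [0.021, 0.019, 0.020, 0.022]

print(f"TTFT: {ttft * 1000:.1f} ms")
print(f"TPOT: {tpot * 1000:.1f} ms")
print(f"Mean ITL: {np.mean(itl) * 1000:.1f} ms, "
      f"P99 ITL: {np.percentile(itl, 99) * 1000:.1f} ms")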
@@ -10,6 +10,8 @@ from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


def sample_requests(
    dataset_path: str,
@@ -29,22 +31,23 @@ def sample_requests(
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        if fixed_output_len is not None:
            output_len = fixed_output_len
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out too long sequences.
    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
@@ -53,9 +56,7 @@ def sample_requests(
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests
    return filtered_dataset


def run_vllm(
@@ -72,47 +73,54 @@ def run_vllm(
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    distributed_executor_backend: Optional[str],
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(model=model,
              tokenizer=tokenizer,
              quantization=quantization,
              tensor_parallel_size=tensor_parallel_size,
              seed=seed,
              trust_remote_code=trust_remote_code,
              dtype=dtype,
              max_model_len=max_model_len,
              gpu_memory_utilization=gpu_memory_utilization,
              enforce_eager=enforce_eager,
              kv_cache_dtype=kv_cache_dtype,
              device=device,
              enable_prefix_caching=enable_prefix_caching,
              download_dir=download_dir)
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        distributed_executor_backend=distributed_executor_backend,
    )

    # Add the requests to the engine.
    prompts = []
    sampling_params = []
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )
        prompts.append(prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=True,
                max_tokens=output_len,
            ))

    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    llm.generate(prompts, sampling_params, use_tqdm=True)
    end = time.perf_counter()
    return end - start

@@ -212,14 +220,15 @@ def main(args: argparse.Namespace):
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype,
                                args.max_model_len, args.enforce_eager,
                                args.kv_cache_dtype, args.device,
                                args.enable_prefix_caching,
                                args.gpu_memory_utilization, args.download_dir)
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.distributed_executor_backend,
            args.gpu_memory_utilization, args.download_dir)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -235,6 +244,18 @@ def main(args: argparse.Namespace):
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
@@ -259,7 +280,7 @@ if __name__ == "__main__":
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
@@ -304,27 +325,58 @@ if __name__ == "__main__":
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        '--kv-cache-dtype',
        type=str,
        choices=["auto", "fp8_e5m2"],
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument(
        '--distributed-executor-backend',
        choices=['ray', 'mp'],
        default=None,
        help='Backend to use for distributed serving. When more than 1 GPU '
        'is used, will be automatically set to "ray" if installed '
        'or "mp" (multiprocessing) otherwise.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
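The diff above (apparently benchmarks/benchmark_throughput.py) drops the internal llm._add_request / llm._run_engine calls in favor of the public API: one SamplingParams per request, then a single llm.generate call over the prompt and parameter lists. A hedged sketch of that pattern follows; it assumes a working vLLM installation with a GPU, the model and requests are placeholders, and it is not part of the diff.

import time
from vllm import LLM, SamplingParams

requests = [("Hello, my name is", 32), ("The capital of France is", 16)]
llm = LLM(model="facebook/opt-125m")

prompts = [prompt for prompt, _ in requests]
sampling_params = [
    SamplingParams(temperature=1.0, top_p=1.0, ignore_eos=True,
                   max_tokens=output_len)
    for _, output_len in requests
]

start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)  # one batched call via the public API
elapsed = time.perf_counter() - start
print(f"Throughput: {len(requests) / elapsed:.2f} requests/s")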

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py (new file, 352 lines)
@@ -0,0 +1,352 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

# helpers


def to_fp8(tensor: torch.tensor) -> torch.tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.tensor) -> torch.tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
                      k: int) -> Tuple[torch.tensor, torch.tensor]:

    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    if dtype == torch.int8:
        return to_int8(a), to_int8(b)
    if dtype == torch.float8_e4m3fn:
        return to_fp8(a), to_fp8(b)

    raise ValueError("unsupported dtype")


# impl


def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                    scale_b: torch.tensor,
                    out_dtype: torch.dtype) -> torch.tensor:
    return torch.mm(a, b)


def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
    return torch._scaled_mm(a,
                            b,
                            scale_a=scale_a,
                            scale_b=scale_b,
                            out_dtype=out_dtype)


def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
                                scale_a: torch.tensor, scale_b: torch.tensor,
                                out_dtype: torch.dtype) -> torch.tensor:
    return torch._scaled_mm(a,
                            b,
                            scale_a=scale_a,
                            scale_b=scale_b,
                            out_dtype=out_dtype,
                            use_fast_accum=True)


def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                 scale_b: torch.tensor,
                 out_dtype: torch.dtype) -> torch.tensor:
    return ops.cutlass_scaled_mm_dq(a,
                                    b,
                                    scale_a,
                                    scale_b,
                                    out_dtype=out_dtype)


# bench
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
             scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
             sub_label: str, fn: Callable, description: str) -> TMeasurement:

    min_run_time = 1

    globals = {
        "a": a,
        "b": b,
        "scale_a": scale_a,
        "scale_b": scale_b,
        "out_dtype": out_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(a, b, scale_a, scale_b, out_dtype)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
               sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.int8
    a, b = make_rand_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []
    # pytorch impl
    timers.append(
        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))

    # cutlass impl
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.bfloat16, label, sub_label, cutlass_impl,
                 "cutlass_i8_i8_bf16_scaled_mm"))

    return timers


def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
              sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.float8_e4m3fn
    a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []

    # pytorch impl: bf16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))

    # pytorch impl: bf16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))

    # pytorch impl: fp16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))

    # pytorch impl: fp16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))

    # cutlass impl: bf16 output
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.bfloat16, label, sub_label, cutlass_impl,
                 "cutlass_fp8_fp8_bf16_scaled_mm"))
    # cutlass impl: fp16 output
    timers.append(
        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
                 torch.float16, label, sub_label, cutlass_impl,
                 "cutlass_fp8_fp8_fp16_scaled_mm"))
    return timers


def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")


# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:

    results = []
    for m, k, n in MKNs:
        timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = argparse.ArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--dtype",
                        type=to_torch_dtype,
                        required=True,
                        help="Available options are ['int8', 'fp8']")
    subparsers = parser.add_subparsers(dest="cmd")

    square_parser = subparsers.add_parser("square_bench")
 | 
			
		||||
    square_parser.add_argument("--dim-start", type=int, required=True)
 | 
			
		||||
    square_parser.add_argument("--dim-end", type=int, required=True)
 | 
			
		||||
    square_parser.add_argument("--dim-increment", type=int, required=True)
 | 
			
		||||
    square_parser.set_defaults(func=run_square_bench)
 | 
			
		||||
 | 
			
		||||
    range_parser = subparsers.add_parser("range_bench")
 | 
			
		||||
    range_parser.add_argument("--dim-start", type=int, required=True)
 | 
			
		||||
    range_parser.add_argument("--dim-end", type=int, required=True)
 | 
			
		||||
    range_parser.add_argument("--dim-increment", type=int, required=True)
 | 
			
		||||
    range_parser.add_argument("--m-constant", type=int, default=None)
 | 
			
		||||
    range_parser.add_argument("--n-constant", type=int, default=None)
 | 
			
		||||
    range_parser.add_argument("--k-constant", type=int, default=None)
 | 
			
		||||
    range_parser.set_defaults(func=run_range_bench)
 | 
			
		||||
 | 
			
		||||
    model_parser = subparsers.add_parser("model_bench")
 | 
			
		||||
    model_parser.add_argument("--models",
 | 
			
		||||
                              nargs="+",
 | 
			
		||||
                              type=str,
 | 
			
		||||
                              default=DEFAULT_MODELS,
 | 
			
		||||
                              choices=WEIGHT_SHAPES.keys())
 | 
			
		||||
    model_parser.add_argument("--tp-sizes",
 | 
			
		||||
                              nargs="+",
 | 
			
		||||
                              type=int,
 | 
			
		||||
                              default=DEFAULT_TP_SIZES)
 | 
			
		||||
    model_parser.add_argument("--batch-sizes",
 | 
			
		||||
                              nargs="+",
 | 
			
		||||
                              type=int,
 | 
			
		||||
                              default=DEFAULT_BATCH_SIZES)
 | 
			
		||||
    model_parser.set_defaults(func=run_model_bench)
 | 
			
		||||
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
    args.func(args)
 | 
			
		||||
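
Note: the benchmark dumps its raw measurements with pickle; below is a minimal sketch (not part of the diff) of how such a .pkl file could be inspected afterwards. The file name is hypothetical; use whatever make_output() actually wrote.

import pickle as pkl
import torch.utils.benchmark as TBenchmark

# Hypothetical output name from a square_bench run; substitute the real one.
with open("square_bench-torch.float8_e4m3fn-1715000000.pkl", "rb") as f:
    timers = pkl.load(f)

# The pickle holds a list of torch.utils.benchmark Measurement objects, so it
# can be fed straight back into Compare for a formatted table.
TBenchmark.Compare(timers).print()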
							
								
								
									
benchmarks/cutlass_benchmarks/weight_shapes.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
#  A shape of ([14336, 4096], 0) indicates the following GEMM shape,
#   - TP1 : K = 14336, N = 4096
#   - TP2 : K = 7168, N = 4096
#  A shape of ([4096, 6144], 1) indicates the following GEMM shape,
#   - TP1 : K = 4096, N = 6144
#   - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "mistralai/Mistral-7B-v0.1": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-2-7b-hf": [
        ([4096, 12288], 1),
        ([4096, 4096], 0),
        ([4096, 22016], 1),
        ([11008, 4096], 0),
    ],
    "meta-llama/Llama-2-13b-hf": [
        ([5120, 15360], 1),
        ([5120, 5120], 0),
        ([5120, 27648], 1),
        ([13824, 5120], 0),
    ],
    "meta-llama/Llama-2-70b-hf": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
}
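
A minimal sketch (not part of the diff) of how a ([K, N], TP_SPLIT_DIM) entry shrinks with tensor-parallel size, mirroring the worked examples in the comment above; shard_kn is a hypothetical helper, not something the benchmarks define.

def shard_kn(kn, tp_split_dim, tp_size):
    # Only the dimension flagged by TP_SPLIT_DIM is divided across TP ranks.
    kn = list(kn)
    kn[tp_split_dim] //= tp_size
    return kn

assert shard_kn([14336, 4096], 0, 2) == [7168, 4096]  # the TP2 example above
assert shard_kn([4096, 6144], 1, 4) == [4096, 1536]   # the TP4 example above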
							
								
								
									
benchmarks/kernels/benchmark_aqlm.py (new file, 302 lines)
@@ -0,0 +1,302 @@
import argparse
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
    dequantize_weight, generic_dequantize_gemm, get_int_dtype,
    optimized_dequantize_gemm)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def torch_mult(
        input: torch.Tensor,  #  [..., in_features]
        weights: torch.Tensor,
        scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    output = F.linear(input, weights)
    return output


def dequant_out_scale(
    input: torch.Tensor,  #  [..., in_features]
    codes: torch.IntTensor,  #  [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  #  [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)


def dequant_weight_scale(
    input: torch.Tensor,  #  [..., in_features]
    codes: torch.IntTensor,  #  [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  #  [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, weights.shape[1])
    weights *= b_scales
    return F.linear(input, weights, bias)


def dequant_no_scale(
    input: torch.Tensor,  #  [..., in_features]
    codes: torch.IntTensor,  #  [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  #  [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  #  [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:

    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    return F.linear(input, weights, bias)


# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
            count += 1

    print("codes shape", codes.shape)

    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])


def main():

    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                         (4096, (11008, 11008)), (11008, (4096, )))

        # reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

        sys.stdout = sys.__stdout__


def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
             methods):

    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        print(f' | {kernel_dur_us:.0f}', end='')

    print('')


def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
               nbooks: int, bits: int, method) -> float:

    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
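
run_timing above measures a burst of num_calls kernel launches with CUDA events and reports the average per call. A standalone sketch of that timing pattern (assumes a CUDA device; time_gpu_op is a hypothetical helper, not part of the benchmark):

import torch

def time_gpu_op(fn, num_calls: int = 100) -> float:
    # Average milliseconds per call, measured with CUDA events around a burst
    # of num_calls launches, as run_timing does above.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_calls):
        fn()
    end_event.record()
    end_event.synchronize()
    return start_event.elapsed_time(end_event) / num_calls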
							
								
								
									
benchmarks/kernels/benchmark_marlin.py (new file, 233 lines)
@@ -0,0 +1,233 @@
import argparse

import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)

DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]


def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
              size_m, size_k, size_n):
    label = "Quant Matmul"

    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
                                         group_size, size_m, size_k, size_n))

    print(f"Testing: {sub_label}")

    a = torch.randn(size_m, size_k).to(torch.half).cuda()
    b = torch.rand(size_k, size_n).to(torch.half).cuda()

    a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())

    # Marlin quant
    (
        marlin_w_ref,
        marlin_q_w,
        marlin_s,
        marlin_g_idx,
        marlin_sort_indices,
        marlin_rand_perm,
    ) = marlin_quantize(b, num_bits, group_size, act_order)

    # Marlin_24 quant
    (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)

    # GPTQ quant
    (w_ref, q_w, s, g_idx,
     rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx"
    # so that group ids are increasing
    repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
    if act_order:
        (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)

    # Prepare
    marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                       GPTQ_MARLIN_MAX_PARALLEL)

    marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                          GPTQ_MARLIN_24_MAX_PARALLEL)

    globals = {
        # Gen params
        "num_bits": num_bits,
        "group_size": group_size,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "a": a,
        "a_tmp": a_tmp,
        # Marlin params
        "marlin_w_ref": marlin_w_ref,
        "marlin_q_w": marlin_q_w,
        "marlin_s": marlin_s,
        "marlin_g_idx": marlin_g_idx,
        "marlin_sort_indices": marlin_sort_indices,
        "marlin_rand_perm": marlin_rand_perm,
        "marlin_workspace": marlin_workspace,
        "is_k_full": is_k_full,
        # Marlin_24 params
        "marlin_24_w_ref": marlin_24_w_ref,
        "marlin_24_q_w_comp": marlin_24_q_w_comp,
        "marlin_24_meta": marlin_24_meta,
        "marlin_24_s": marlin_24_s,
        "marlin_24_workspace": marlin_24_workspace,
        # GPTQ params
        "q_w_gptq": q_w_gptq,
        "repack_sort_indices": repack_sort_indices,
        # Kernels
        "gptq_marlin_gemm": ops.gptq_marlin_gemm,
        "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
        "gptq_marlin_repack": ops.gptq_marlin_repack,
    }

    min_run_time = 1

    # Warmup pytorch
    for i in range(5):
        torch.matmul(a, marlin_w_ref)

    results.append(
        benchmark.Timer(
            stmt="torch.matmul(a, marlin_w_ref)",
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="pytorch_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_gemm",
        ).blocked_autorange(min_run_time=min_run_time))

    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
        results.append(
            benchmark.Timer(
                stmt=
                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
                globals=globals,
                label=label,
                sub_label=sub_label,
                description="gptq_marlin_24_gemm",
            ).blocked_autorange(min_run_time=min_run_time))

    results.append(
        benchmark.Timer(
            stmt=
            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
            globals=globals,
            label=label,
            sub_label=sub_label,
            description="gptq_marlin_repack",
        ).blocked_autorange(min_run_time=min_run_time))


def main(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    results = []

    for model in args.models:
        for layer in WEIGHT_SHAPES[model]:
            size_k = layer[0]
            size_n = layer[1]

            if len(args.limit_k) > 0 and size_k not in args.limit_k:
                continue

            if len(args.limit_n) > 0 and size_n not in args.limit_n:
                continue

            for act_order in ACT_ORDER_OPTS:
                if len(args.limit_act_order
                       ) > 0 and act_order not in args.limit_act_order:
                    continue

                for is_k_full in K_FULL_OPTS:
                    if len(args.limit_k_full
                           ) > 0 and is_k_full not in args.limit_k_full:
                        continue

                    for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
                        if len(args.limit_num_bits
                               ) > 0 and num_bits not in args.limit_num_bits:
                            continue

                        for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
                            if len(
                                    args.limit_group_size
                            ) > 0 and group_size not in args.limit_group_size:
                                continue

                            # For act_order, the group_size must be less than
                            # size_k
                            if act_order and (group_size == size_k
                                              or group_size == -1):
                                continue

                            for size_m in args.batch_sizes:
                                bench_run(results, model, act_order, is_k_full,
                                          num_bits, group_size, size_m, size_k,
                                          size_n)

    compare = benchmark.Compare(results)
    compare.print()


# For quick benchmarking use:
#   python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark Marlin across specified models/shapes/batches")
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    parser.add_argument("--batch-sizes",
                        nargs="+",
                        type=int,
                        default=DEFAULT_BATCH_SIZES)
    parser.add_argument("--limit-k", nargs="+", type=int, default=[])
    parser.add_argument("--limit-n", nargs="+", type=int, default=[])
    parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
    parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
    parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
    parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])

    args = parser.parse_args()
    main(args)
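
The Marlin benchmark times kernels with torch.utils.benchmark.Timer and blocked_autorange rather than manual CUDA events. A minimal sketch of that pattern, with a plain matmul standing in for the quantized kernels (the shapes are hypothetical; assumes a CUDA device):

import torch
import torch.utils.benchmark as benchmark

a = torch.randn(16, 4096, dtype=torch.half, device="cuda")
b = torch.randn(4096, 4096, dtype=torch.half, device="cuda")

timer = benchmark.Timer(
    stmt="torch.matmul(a, b)",
    globals={"a": a, "b": b},
    label="Quant Matmul",
    sub_label="example MKN=(16x4096x4096)",
    description="pytorch_gemm",
)
# blocked_autorange keeps measuring until at least min_run_time seconds of
# samples are collected and returns a Measurement, as used in bench_run above.
print(timer.blocked_autorange(min_run_time=1))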
@@ -1,182 +0,0 @@
import json
import os
import sys

import torch
import torch.nn.functional as F
import triton

from vllm.model_executor.layers.fused_moe import (fused_moe,
                                                  get_config_file_name)

os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def main():
    method = fused_moe
    for bs in [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
    ]:
        run_grid(bs, method=method)


def run_grid(bs, method):
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    configs = []
    if bs <= 16:
        BLOCK_SIZES_M = [16]
    elif bs <= 32:
        BLOCK_SIZES_M = [16, 32]
    elif bs <= 64:
        BLOCK_SIZES_M = [16, 32, 64]
    elif bs <= 128:
        BLOCK_SIZES_M = [16, 32, 64, 128]
    else:
        BLOCK_SIZES_M = [16, 32, 64, 128, 256]

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in BLOCK_SIZES_M:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        configs.append({
                            "BLOCK_SIZE_M": block_size_m,
                            "BLOCK_SIZE_N": block_size_n,
                            "BLOCK_SIZE_K": block_size_k,
                            "GROUP_SIZE_M": group_size_m,
                            "num_warps": num_warps,
                            "num_stages": 4,
                        })

    best_config = None
    best_time_us = 1e20

    for config in configs:
        print(f'{tp_size=} {bs=}')
        print(f'{config}')
        # warmup
        print('warming up')
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                )
        except triton.runtime.autotuner.OutOfResources:
            continue

        # trial
        print('benchmarking')
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                  f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                  f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size)
    print(f"writing config to file {filename}")
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")


def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config) -> float:
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.bfloat16,
    )

    ws = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2s = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=ws,
            w2=w2s,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms


if __name__ == "__main__":
    sys.exit(main())
							
								
								
									
benchmarks/kernels/benchmark_moe.py (new file, 322 lines)
@@ -0,0 +1,322 @@
 | 
			
		||||
import argparse
 | 
			
		||||
import time
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from typing import Any, Dict, List, Tuple
 | 
			
		||||
 | 
			
		||||
import ray
 | 
			
		||||
import torch
 | 
			
		||||
import triton
 | 
			
		||||
from ray.experimental.tqdm_ray import tqdm
 | 
			
		||||
from transformers import AutoConfig
 | 
			
		||||
 | 
			
		||||
from vllm.model_executor.layers.fused_moe.fused_moe import *
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def benchmark_config(
 | 
			
		||||
    config: Dict[str, int],
 | 
			
		||||
    num_tokens: int,
 | 
			
		||||
    num_experts: int,
 | 
			
		||||
    shard_intermediate_size: int,
 | 
			
		||||
    hidden_size: int,
 | 
			
		||||
    topk: int,
 | 
			
		||||
    dtype: torch.dtype,
 | 
			
		||||
    use_fp8: bool,
 | 
			
		||||
    num_iters: int = 100,
 | 
			
		||||
) -> float:
 | 
			
		||||
    init_dtype = torch.float16 if use_fp8 else dtype
 | 
			
		||||
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
 | 
			
		||||
    w1 = torch.randn(num_experts,
 | 
			
		||||
                     shard_intermediate_size,
 | 
			
		||||
                     hidden_size,
 | 
			
		||||
                     dtype=init_dtype)
 | 
			
		||||
    w2 = torch.randn(num_experts,
 | 
			
		||||
                     hidden_size,
 | 
			
		||||
                     shard_intermediate_size // 2,
 | 
			
		||||
                     dtype=init_dtype)
 | 
			
		||||
    gating_output = torch.randn(num_iters,
 | 
			
		||||
                                num_tokens,
 | 
			
		||||
                                num_experts,
 | 
			
		||||
                                dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
    w1_scale = None
 | 
			
		||||
    w2_scale = None
 | 
			
		||||
    a1_scale = None
 | 
			
		||||
    a2_scale = None
 | 
			
		||||
    if use_fp8:
 | 
			
		||||
        w1_scale = torch.randn(num_experts, dtype=torch.float32)
 | 
			
		||||
        w2_scale = torch.randn(num_experts, dtype=torch.float32)
 | 
			
		||||
        a1_scale = torch.randn(1, dtype=torch.float32)
 | 
			
		||||
        a2_scale = torch.randn(1, dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
        w1 = w1.to(torch.float8_e4m3fn)
 | 
			
		||||
        w2 = w2.to(torch.float8_e4m3fn)
 | 
			
		||||
 | 
			
		||||
    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
 | 
			
		||||
 | 
			
		||||
    def prepare(i: int):
 | 
			
		||||
        input_gating.copy_(gating_output[i])
 | 
			
		||||
 | 
			
		||||
    def run():
 | 
			
		||||
        fused_moe(
 | 
			
		||||
            x,
 | 
			
		||||
            w1,
 | 
			
		||||
            w2,
 | 
			
		||||
            input_gating,
 | 
			
		||||
            topk,
 | 
			
		||||
            renormalize=True,
 | 
			
		||||
            inplace=True,
 | 
			
		||||
            override_config=config,
 | 
			
		||||
            use_fp8=use_fp8,
 | 
			
		||||
            w1_scale=w1_scale,
 | 
			
		||||
            w2_scale=w2_scale,
 | 
			
		||||
            a1_scale=a1_scale,
 | 
			
		||||
            a2_scale=a2_scale,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # JIT compilation & warmup
 | 
			
		||||
    run()
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    # Capture 10 invocations with CUDA graph
 | 
			
		||||
    graph = torch.cuda.CUDAGraph()
 | 
			
		||||
    with torch.cuda.graph(graph):
 | 
			
		||||
        for _ in range(10):
 | 
			
		||||
            run()
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    # Warmup
 | 
			
		||||
    for _ in range(5):
 | 
			
		||||
        graph.replay()
 | 
			
		||||
    torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
    start_event = torch.cuda.Event(enable_timing=True)
 | 
			
		||||
    end_event = torch.cuda.Event(enable_timing=True)
 | 
			
		||||
 | 
			
		||||
    latencies = []
 | 
			
		||||
    for i in range(num_iters):
 | 
			
		||||
        prepare(i)
 | 
			
		||||
        torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        start_event.record()
 | 
			
		||||
        graph.replay()
 | 
			
		||||
        end_event.record()
 | 
			
		||||
        end_event.synchronize()
 | 
			
		||||
        latencies.append(start_event.elapsed_time(end_event))
 | 
			
		||||
    avg = sum(latencies) / (num_iters * 10) * 1000  # us
 | 
			
		||||
    graph.reset()
 | 
			
		||||
    return avg
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
 | 
			
		||||
    # Reduced search space for faster tuning.
 | 
			
		||||
    # TODO(woosuk): Increase the search space and use a performance model to
 | 
			
		||||
    # prune the search space.
 | 
			
		||||
    configs = []
 | 
			
		||||
    for num_stages in [2, 3, 4, 5]:
 | 
			
		||||
        for block_m in [16, 32, 64, 128, 256]:
 | 
			
		||||
            for block_k in [64, 128, 256]:
 | 
			
		||||
                for block_n in [32, 64, 128, 256]:
 | 
			
		||||
                    for num_warps in [4, 8]:
 | 
			
		||||
                        for group_size in [1, 16, 32, 64]:
 | 
			
		||||
                            configs.append({
 | 
			
		||||
                                "BLOCK_SIZE_M": block_m,
 | 
			
		||||
                                "BLOCK_SIZE_N": block_n,
 | 
			
		||||
                                "BLOCK_SIZE_K": block_k,
 | 
			
		||||
                                "GROUP_SIZE_M": group_size,
 | 
			
		||||
                                "num_warps": num_warps,
 | 
			
		||||
                                "num_stages": num_stages,
 | 
			
		||||
                            })
 | 
			
		||||
    return configs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ray.remote(num_gpus=1)
 | 
			
		||||
class BenchmarkWorker:
 | 
			
		||||
 | 
			
		||||
    def __init__(self, seed: int) -> None:
 | 
			
		||||
        torch.set_default_device("cuda")
 | 
			
		||||
        torch.cuda.manual_seed_all(seed)
 | 
			
		||||
        self.seed = seed
 | 
			
		||||
 | 
			
		||||
    def benchmark(
 | 
			
		||||
        self,
 | 
			
		||||
        num_tokens: int,
 | 
			
		||||
        num_experts: int,
 | 
			
		||||
        shard_intermediate_size: int,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        topk: int,
 | 
			
		||||
        dtype: torch.dtype,
 | 
			
		||||
        use_fp8: bool,
 | 
			
		||||
    ) -> Tuple[Dict[str, int], float]:
 | 
			
		||||
        torch.cuda.manual_seed_all(self.seed)
 | 
			
		||||
 | 
			
		||||
        dtype_str = "float8" if use_fp8 else None
 | 
			
		||||
        # NOTE(woosuk): The current naming convention uses w2.shape[2], which
 | 
			
		||||
        # is the intermediate size after silu_and_mul.
 | 
			
		||||
        op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
 | 
			
		||||
                                    dtype_str)
 | 
			
		||||
        if op_config is None:
 | 
			
		||||
            config = get_default_config(num_tokens, num_experts,
 | 
			
		||||
                                        shard_intermediate_size, hidden_size,
 | 
			
		||||
                                        topk, dtype_str)
 | 
			
		||||
        else:
 | 
			
		||||
            config = op_config[min(op_config.keys(),
 | 
			
		||||
                                   key=lambda x: abs(x - num_tokens))]
 | 
			
		||||
        kernel_time = benchmark_config(config, num_tokens, num_experts,
 | 
			
		||||
                                       shard_intermediate_size, hidden_size,
 | 
			
		||||
                                       topk, dtype, use_fp8)
 | 
			
		||||
        return config, kernel_time
 | 
			
		||||
 | 
			
		||||
    def tune(
 | 
			
		||||
        self,
 | 
			
		||||
        num_tokens: int,
 | 
			
		||||
        num_experts: int,
 | 
			
		||||
        shard_intermediate_size: int,
 | 
			
		||||
        hidden_size: int,
 | 
			
		||||
        topk: int,
 | 
			
		||||
        dtype: torch.dtype,
 | 
			
		||||
        use_fp8: bool,
 | 
			
		||||
        search_space: List[Dict[str, int]],
 | 
			
		||||
    ) -> Dict[str, int]:
 | 
			
		||||
        best_config = None
 | 
			
		||||
        best_time = float("inf")
 | 
			
		||||
        for config in tqdm(search_space):
 | 
			
		||||
            try:
 | 
			
		||||
                kernel_time = benchmark_config(config,
 | 
			
		||||
                                               num_tokens,
 | 
			
		||||
                                               num_experts,
 | 
			
		||||
                                               shard_intermediate_size,
 | 
			
		||||
                                               hidden_size,
 | 
			
		||||
                                               topk,
 | 
			
		||||
                                               dtype,
 | 
			
		||||
                                               use_fp8,
 | 
			
		||||
                                               num_iters=10)
 | 
			
		||||
            except triton.runtime.autotuner.OutOfResources:
 | 
			
		||||
                # Some configurations may be invalid and fail to compile.
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            if kernel_time < best_time:
 | 
			
		||||
                best_time = kernel_time
 | 
			
		||||
                best_config = config
 | 
			
		||||
        now = datetime.now()
 | 
			
		||||
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
 | 
			
		||||
        return best_config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sort_config(config: Dict[str, int]) -> Dict[str, int]:
 | 
			
		||||
    return {
 | 
			
		||||
        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
 | 
			
		||||
        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
 | 
			
		||||
        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
 | 
			
		||||
        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
 | 
			
		||||
        "num_warps": config["num_warps"],
 | 
			
		||||
        "num_stages": config["num_stages"],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def save_configs(
 | 
			
		||||
    configs: Dict[int, Dict[str, int]],
 | 
			
		||||
    num_experts: int,
 | 
			
		||||
    shard_intermediate_size: int,
 | 
			
		||||
    hidden_size: int,
 | 
			
		||||
    topk: int,
 | 
			
		||||
    dtype: torch.dtype,
 | 
			
		||||
    use_fp8: bool,
 | 
			
		||||
) -> None:
    dtype_str = "float8" if use_fp8 else None
    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
    # is the intermediate size after silu_and_mul.
    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
                                    dtype_str)
    print(f"Writing best config to {filename}...")
    with open(filename, "w") as f:
        json.dump(configs, f, indent=4)
        f.write("\n")


def main(args: argparse.Namespace):
    print(args)

    config = AutoConfig.from_pretrained(args.model)
    if config.architectures[0] == "DbrxForCausalLM":
        E = config.ffn_config.moe_num_experts
        topk = config.ffn_config.moe_top_k
        intermediate_size = config.ffn_config.ffn_hidden_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size
    else:
        # Default: Mixtral.
        E = config.num_local_experts
        topk = config.num_experts_per_tok
        intermediate_size = config.intermediate_size
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
    dtype = config.torch_dtype
    use_fp8 = args.dtype == "fp8"

    if args.batch_size is None:
        batch_sizes = [
            1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
            2048, 3072, 4096
        ]
    else:
        batch_sizes = [args.batch_size]

    ray.init()
    num_gpus = int(ray.available_resources()["GPU"])
    workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

    def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        outputs = []
        worker_idx = 0
        for input_args in inputs:
            worker = workers[worker_idx]
            worker_method = getattr(worker, method)
            output = worker_method.remote(*input_args)
            outputs.append(output)
            worker_idx = (worker_idx + 1) % num_gpus
        return ray.get(outputs)

    if args.tune:
        search_space = get_configs_compute_bound()
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
        configs = _distribute(
            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
                      topk, dtype, use_fp8, search_space)
                     for batch_size in batch_sizes])
        best_configs = {
            M: sort_config(config)
            for M, config in zip(batch_sizes, configs)
        }
        save_configs(best_configs, E, shard_intermediate_size, hidden_size,
                     topk, dtype, use_fp8)
        end = time.time()
        print(f"Tuning took {end - start:.2f} seconds")
    else:
        outputs = _distribute("benchmark",
                              [(batch_size, E, shard_intermediate_size,
                                hidden_size, topk, dtype, use_fp8)
                               for batch_size in batch_sizes])

        for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")
            print(f"Kernel time: {kernel_time:.2f} us")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        type=str,
                        default="mistralai/Mixtral-8x7B-Instruct-v0.1")
    parser.add_argument("--tp-size", "-tp", type=int, default=2)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["auto", "fp8"],
                        default="auto")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    main(args)
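The tuning path above writes the best config per batch size M to a JSON file. As a rough illustration of how such a file could be consumed, here is a minimal, hypothetical loader; the filename and the nearest-M fallback are assumptions for the sketch, not the repository's actual lookup code.

import json


def load_best_config(filename: str, batch_size: int) -> dict:
    """Pick the tuned kernel config whose batch size M is closest to the request."""
    with open(filename) as f:
        # Keys were written as batch sizes (M); JSON turns them into strings.
        configs = {int(m): cfg for m, cfg in json.load(f).items()}
    nearest_m = min(configs, key=lambda m: abs(m - batch_size))
    return configs[nearest_m]


# Example (assumes a file produced by the tuning run above):
# best = load_best_config("moe_tuned_config.json", batch_size=48)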
@@ -5,7 +5,7 @@ from typing import Optional
 
 import torch
 
-from vllm._C import ops
+from vllm import _custom_ops as ops
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
 
 NUM_BLOCKS = 1024
@@ -16,7 +16,7 @@ PARTITION_SIZE = 512
 def main(
     version: str,
     num_seqs: int,
-    context_len: int,
+    seq_len: int,
     num_query_heads: int,
     num_kv_heads: int,
     head_size: int,
@@ -48,12 +48,12 @@ def main(
                                    dtype=torch.float,
                                    device=device)
 
-    context_lens = [context_len for _ in range(num_seqs)]
-    max_context_len = max(context_lens)
-    context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
+    seq_lens = [seq_len for _ in range(num_seqs)]
+    max_seq_len = max(seq_lens)
+    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
 
     # Create the block tables.
-    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = []
     for _ in range(num_seqs):
         block_table = [
@@ -77,8 +77,7 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
-        num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
-                          PARTITION_SIZE)
+        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
             dtype=output.dtype,
@@ -97,6 +96,9 @@ def main(
             torch.cuda.cudart().cudaProfilerStart()
         start_time = time.perf_counter()
 
+        # Using default kv_scale
+        kv_scale = 1.0
+
         for _ in range(num_iters):
             if version == "v1":
                 ops.paged_attention_v1(
@@ -107,11 +109,12 @@ def main(
                     num_kv_heads,
                     scale,
                     block_tables,
-                    context_lens,
+                    seq_lens,
                     block_size,
-                    max_context_len,
+                    max_seq_len,
                     alibi_slopes,
                     kv_cache_dtype,
+                    kv_scale,
                 )
             elif version == "v2":
                 ops.paged_attention_v2(
@@ -125,11 +128,12 @@ def main(
                     num_kv_heads,
                     scale,
                     block_tables,
-                    context_lens,
+                    seq_lens,
                     block_size,
-                    max_context_len,
+                    max_seq_len,
                     alibi_slopes,
                     kv_cache_dtype,
+                    kv_scale,
                 )
             else:
                 raise ValueError(f"Invalid version: {version}")
@@ -161,12 +165,12 @@ if __name__ == '__main__':
                         choices=["v1", "v2"],
                         default="v2")
     parser.add_argument("--batch-size", type=int, default=8)
-    parser.add_argument("--context-len", type=int, default=4096)
+    parser.add_argument("--seq_len", type=int, default=4096)
     parser.add_argument("--num-query-heads", type=int, default=64)
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
@@ -179,11 +183,11 @@ if __name__ == '__main__':
     parser.add_argument(
         "--kv-cache-dtype",
         type=str,
-        choices=["auto", "fp8_e5m2"],
+        choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
         default="auto",
-        help=
-        'Data type for kv cache storage. If "auto", will use model data type.')
-    parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
+        help="Data type for kv cache storage. If 'auto', will use model "
+        "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
+        "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
     args = parser.parse_args()
     print(args)
 
@@ -192,7 +196,7 @@ if __name__ == '__main__':
     main(
         version=args.version,
         num_seqs=args.batch_size,
-        context_len=args.context_len,
+        seq_len=args.seq_len,
         num_query_heads=args.num_query_heads,
         num_kv_heads=args.num_kv_heads,
         head_size=args.head_size,
 
@@ -93,7 +93,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 256],
+                        choices=[64, 80, 96, 112, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
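For the v2 path above, each sequence is split into fixed-size partitions before the final reduction. A small worked sketch of that shape arithmetic, using the benchmark's defaults (PARTITION_SIZE = 512); the helper name is illustrative only.

def v2_partition_shapes(num_seqs: int, num_query_heads: int, head_size: int,
                        max_seq_len: int, partition_size: int = 512):
    # Ceil-divide the longest sequence into partitions, as the benchmark does.
    num_partitions = (max_seq_len + partition_size - 1) // partition_size
    tmp_output_shape = (num_seqs, num_query_heads, num_partitions, head_size)
    return num_partitions, tmp_output_shape


# With the defaults (batch 8, 64 query heads, head size 128, seq len 4096):
# v2_partition_shapes(8, 64, 128, 4096) -> (8, (8, 64, 8, 128))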
							
								
								
									
benchmarks/kernels/benchmark_shapes.py (new file, 75 lines)
@@ -0,0 +1,75 @@
WEIGHT_SHAPES = {
    "ideal": [[4 * 256 * 32, 256 * 32]],
    "mistralai/Mistral-7B-v0.1/TP1": [
        [4096, 6144],
        [4096, 4096],
        [4096, 28672],
        [14336, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP2": [
        [4096, 3072],
        [2048, 4096],
        [4096, 14336],
        [7168, 4096],
    ],
    "mistralai/Mistral-7B-v0.1/TP4": [
        [4096, 1536],
        [1024, 4096],
        [4096, 7168],
        [3584, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP1": [
        [4096, 12288],
        [4096, 4096],
        [4096, 22016],
        [11008, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP2": [
        [4096, 6144],
        [2048, 4096],
        [4096, 11008],
        [5504, 4096],
    ],
    "meta-llama/Llama-2-7b-hf/TP4": [
        [4096, 3072],
        [1024, 4096],
        [4096, 5504],
        [2752, 4096],
    ],
    "meta-llama/Llama-2-13b-hf/TP1": [
        [5120, 15360],
        [5120, 5120],
        [5120, 27648],
        [13824, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP2": [
        [5120, 7680],
        [2560, 5120],
        [5120, 13824],
        [6912, 5120],
    ],
    "meta-llama/Llama-2-13b-hf/TP4": [
        [5120, 3840],
        [1280, 5120],
        [5120, 6912],
        [3456, 5120],
    ],
    "meta-llama/Llama-2-70b-hf/TP1": [
        [8192, 10240],
        [8192, 8192],
        [8192, 57344],
        [28672, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP2": [
        [8192, 5120],
        [4096, 8192],
        [8192, 28672],
        [14336, 8192],
    ],
    "meta-llama/Llama-2-70b-hf/TP4": [
        [8192, 2560],
        [2048, 8192],
        [8192, 14336],
        [7168, 8192],
    ],
}
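WEIGHT_SHAPES maps a model and tensor-parallel degree to the [K, N] shapes of its linear layers. A hedged sketch of how the table might drive a simple timing loop; torch.matmul stands in for whatever kernel is actually under test, and importing the table as benchmark_shapes assumes the script's directory is on the path.

import time

import torch

from benchmark_shapes import WEIGHT_SHAPES


def time_shapes(model_key: str, m: int = 16, iters: int = 10) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    for k, n in WEIGHT_SHAPES[model_key]:
        a = torch.randn(m, k, device=device, dtype=dtype)
        b = torch.randn(k, n, device=device, dtype=dtype)
        torch.matmul(a, b)  # warm-up
        if device == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            torch.matmul(a, b)
        if device == "cuda":
            torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / iters
        print(f"[{m}x{k}] x [{k}x{n}]: {elapsed * 1e6:.1f} us")


# time_shapes("meta-llama/Llama-2-7b-hf/TP1")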
@@ -4,7 +4,7 @@ PORT=8000
 MODEL=$1
 TOKENS=$2
 
-docker run --gpus all --shm-size 1g -p $PORT:80 \
+docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
            -v $PWD/data:/data \
            ghcr.io/huggingface/text-generation-inference:1.4.0 \
            --model-id $MODEL \
							
								
								
									
benchmarks/overheads/benchmark_hashing.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
               ] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)


def main(args):
    llm = LLM(
        model=args.model,
        enforce_eager=True,
        enable_prefix_caching=True,
        tensor_parallel_size=args.tensor_parallel_size,
        use_v2_block_manager=args.use_v2_block_manager,
    )

    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
    profiler = cProfile.Profile()

    print("------warm up------")
    for i in range(3):
        output = llm.generate(LONG_PROMPT, sampling_params)
        print(output[0].outputs[0].text)

    print("------start generating------")
    for i in range(3):
        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
                        globals(), locals())

    # analyze the runtime of hashing function
    stats = pstats.Stats(profiler)
    stats.sort_stats('cumulative')
    total_time = 0
    total_calls = 0
    for func in stats.stats:
        if 'hash_of_block' in func[2]:
            total_time = stats.stats[func][3]
            total_calls = stats.stats[func][0]
    percentage = (total_time / stats.total_tt) * 100
    print(f"Hashing took {total_time:.2f} seconds,"
          f"{percentage:.2f}% of the total runtime.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Benchmark the performance of hashing function in'
        'automatic prefix caching.')
    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--output-len', type=int, default=10)
    parser.add_argument('--enable-prefix-caching',
                        action='store_true',
                        help='enable prefix caching')
    parser.add_argument('--use-v2-block-manager',
                        action='store_true',
                        help='Use BlockSpaceMangerV2')
    args = parser.parse_args()
    main(args)
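The script above singles out hash_of_block entries in the cProfile output. The same pstats filtering pattern, shown on a self-contained dummy workload (the function here is a stand-in, not a vLLM internal):

import cProfile
import pstats


def hash_of_block(data: bytes) -> int:
    # Stand-in for the profiled target function.
    return hash(data)


def workload() -> None:
    for i in range(10000):
        hash_of_block(str(i).encode())


profiler = cProfile.Profile()
profiler.runctx("workload()", globals(), locals())

stats = pstats.Stats(profiler)
total_time = 0.0
for func, (cc, nc, tt, ct, callers) in stats.stats.items():
    # func is (filename, lineno, function_name); ct is cumulative time.
    if "hash_of_block" in func[2]:
        total_time += ct
print(f"hash_of_block cumulative time: {total_time:.4f}s "
      f"({100 * total_time / stats.total_tt:.1f}% of total)")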
							
								
								
									
cmake/cpu_extension.cmake (new file, 90 lines)
@@ -0,0 +1,90 @@
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

#
# Define environment variables for special configurations
#
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
    set(ENABLE_AVX512BF16 ON)
endif()

include_directories("${CMAKE_SOURCE_DIR}/csrc")

#
# Check the compile flags
#
list(APPEND CXX_COMPILE_FLAGS
    "-fopenmp"
    "-DVLLM_CPU_EXTENSION")

execute_process(COMMAND cat /proc/cpuinfo
                RESULT_VARIABLE CPUINFO_RET
                OUTPUT_VARIABLE CPUINFO)

if (NOT CPUINFO_RET EQUAL 0)
    message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
endif()

function (find_isa CPUINFO TARGET OUT)
    string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
    if(NOT ISA_FOUND EQUAL -1)
        set(${OUT} ON PARENT_SCOPE)
    else()
        set(${OUT} OFF PARENT_SCOPE)
    endif()
endfunction()

find_isa(${CPUINFO} "avx512f" AVX512_FOUND)

if (AVX512_FOUND)
    list(APPEND CXX_COMPILE_FLAGS
        "-mavx512f"
        "-mavx512vl"
        "-mavx512bw"
        "-mavx512dq")

    find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
    if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
        if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
            CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
            list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
        else()
            message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
        endif()
    else()
        message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
    endif()
else()
    message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
endif()

message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")


#
# Define extension targets
#

#
# _C extension
#
set(VLLM_EXT_SRC
    "csrc/cpu/activation.cpp"
    "csrc/cpu/attention.cpp"
    "csrc/cpu/cache.cpp"
    "csrc/cpu/layernorm.cpp"
    "csrc/cpu/pos_encoding.cpp"
    "csrc/cpu/torch_bindings.cpp")

define_gpu_extension_target(
    _C
    DESTINATION vllm
    LANGUAGE CXX
    SOURCES ${VLLM_EXT_SRC}
    COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
    USE_SABI 3
    WITH_SOABI
)

add_custom_target(default)
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
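The CMake file above gates the CPU backend on AVX512 flags read from /proc/cpuinfo. A rough Python equivalent of that check, handy for verifying a Linux host before building; it is purely illustrative and not part of the build.

def cpu_flags(path: str = "/proc/cpuinfo") -> set:
    flags = set()
    with open(path) as f:
        for line in f:
            if line.startswith("flags"):
                flags.update(line.split(":", 1)[1].split())
    return flags


if __name__ == "__main__":
    flags = cpu_flags()
    print("avx512f     :", "avx512f" in flags)      # required by the CPU backend
    print("avx512_bf16 :", "avx512_bf16" in flags)  # optional, needs gcc/g++ >= 12.3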
@@ -5,7 +5,7 @@
 macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
   file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
   set(Python_EXECUTABLE ${EXECUTABLE})
-  find_package(Python COMPONENTS Interpreter Development.Module)
+  find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
   if (NOT Python_FOUND)
     message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
   endif()
@@ -99,7 +99,14 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
       "Failed to determine torch nvcc compiler flags")
 
     if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
-      list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
+      list(APPEND GPU_FLAGS "-DENABLE_FP8")
     endif()
+    if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
+      list(REMOVE_ITEM GPU_FLAGS
+        "-D__CUDA_NO_HALF_OPERATORS__"
+        "-D__CUDA_NO_HALF_CONVERSIONS__"
+        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
+        "-D__CUDA_NO_HALF2_OPERATORS__")
+    endif()
 
   elseif(${GPU_LANG} STREQUAL "HIP")
@@ -112,6 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
 
     list(APPEND GPU_FLAGS
       "-DUSE_ROCM"
+      "-DENABLE_FP8"
       "-U__HIP_NO_HALF_CONVERSIONS__"
       "-U__HIP_NO_HALF_OPERATORS__"
       "-fno-gpu-rdc")
@@ -286,6 +294,7 @@ endmacro()
 # INCLUDE_DIRECTORIES <dirs> - Extra include directories.
 # LIBRARIES <libraries>      - Extra link libraries.
 # WITH_SOABI                 - Generate library with python SOABI suffix name.
+# USE_SABI <version>         - Use python stable api <version>
 #
 # Note: optimization level/debug info is set via cmake build type.
 #
@@ -293,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   cmake_parse_arguments(PARSE_ARGV 1
     GPU
     "WITH_SOABI"
-    "DESTINATION;LANGUAGE"
+    "DESTINATION;LANGUAGE;USE_SABI"
     "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
 
   # Add hipify preprocessing step when building with HIP/ROCm.
@@ -307,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME)
     set(GPU_WITH_SOABI)
   endif()
 
-  Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
+  if (GPU_USE_SABI)
+    Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  else()
+    Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
+  endif()
 
   if (GPU_LANGUAGE STREQUAL "HIP")
     # Make this target dependent on the hipify preprocessor step.
@@ -63,6 +63,8 @@ DEFAULT_CONDA_PATTERNS = {
     "magma",
     "triton",
     "optree",
+    "nccl",
+    "transformers",
 }
 
 DEFAULT_PIP_PATTERNS = {
@@ -73,6 +75,8 @@ DEFAULT_PIP_PATTERNS = {
     "triton",
     "optree",
     "onnx",
+    "nccl",
+    "transformers",
 }
 
 
@@ -599,6 +603,11 @@ Versions of relevant libraries:
 {conda_packages}
 """.strip()
 
+# both the above code and the following code use `strip()` to
+# remove leading/trailing whitespaces, so we need to add a newline
+# in between to separate the two sections
+env_info_fmt += "\n"
+
 env_info_fmt += """
 ROCM Version: {rocm_version}
 Neuron SDK Version: {neuron_sdk_version}
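The pattern sets above are substring filters over installed package names. A small sketch of that kind of filtering with importlib.metadata; the exact matching logic in collect_env.py may differ, so treat this as an approximation.

from importlib import metadata

PATTERNS = {"torch", "nccl", "transformers", "triton"}

relevant = sorted(
    f"{dist.metadata['Name']}=={dist.version}"
    for dist in metadata.distributions()
    if any(p in (dist.metadata["Name"] or "").lower() for p in PATTERNS)
)
print("\n".join(relevant))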
@@ -1,5 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
-#include <torch/extension.h>
+#include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
 #include <cmath>
@@ -10,11 +10,11 @@
 namespace vllm {
 
 // Activation and gating kernel template.
-template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void act_and_mul_kernel(
-  scalar_t* __restrict__ out,               // [..., d]
-  const scalar_t* __restrict__ input,       // [..., 2, d]
-  const int d) {
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
@@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
   }
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T silu_kernel(const T& x) {
   // x * sigmoid(x)
-  return (T) (((float) x) / (1.0f + expf((float) -x)));
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'none' approximation.
   // Refer to:
   // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
-  const float f = (float) x;
+  const float f = (float)x;
   constexpr float ALPHA = M_SQRT1_2;
-  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   // Equivalent to PyTorch GELU with 'tanh' approximation.
   // Refer to:
   // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
-  const float f = (float) x;
+  const float f = (float)x;
   constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
   constexpr float KAPPA = 0.044715;
   float x_cube = f * f * f;
   float inner = BETA * (f + KAPPA * x_cube);
-  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
 }
 
-} // namespace vllm
+}  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                             \
-  int d = input.size(-1) / 2;                                                             \
-  int64_t num_tokens = input.numel() / input.size(-1);                                    \
-  dim3 grid(num_tokens);                                                                  \
-  dim3 block(std::min(d, 1024));                                                          \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
-    input.scalar_type(),                                                                  \
-    "act_and_mul_kernel",                                                                 \
-    [&] {                                                                                 \
-      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(   \
-        out.data_ptr<scalar_t>(),                                                         \
-        input.data_ptr<scalar_t>(),                                                       \
-        d);                                                                               \
-    });
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
+  int d = input.size(-1) / 2;                                            \
+  int64_t num_tokens = input.numel() / input.size(-1);                   \
+  dim3 grid(num_tokens);                                                 \
+  dim3 block(std::min(d, 1024));                                         \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
+  VLLM_DISPATCH_FLOATING_TYPES(                                          \
+      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
+                                         input.data_ptr<scalar_t>(), d); \
+      });
 
-void silu_and_mul(
-  torch::Tensor& out,      // [..., d]
-  torch::Tensor& input)    // [..., 2 * d]
+void silu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
 
-void gelu_and_mul(
-  torch::Tensor& out,      // [..., d]
-  torch::Tensor& input)    // [..., 2 * d]
+void gelu_and_mul(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
 }
 
-void gelu_tanh_and_mul(
-  torch::Tensor& out,      // [..., d]
-  torch::Tensor& input)    // [..., 2 * d]
+void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
+                       torch::Tensor& input)  // [..., 2 * d]
 {
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
 }
@@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
 namespace vllm {
 
 // Element-wise activation kernel template.
-template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
-  scalar_t* __restrict__ out,               // [..., d]
-  const scalar_t* __restrict__ input,       // [..., d]
-  const int d) {
+    scalar_t* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., d]
+    const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
@@ -108,54 +102,49 @@ __global__ void activation_kernel(
   }
 }
 
-} // namespace vllm
+}  // namespace vllm
 
 // Launch element-wise activation kernel.
-#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                  \
-  int d = input.size(-1);                                                                 \
-  int64_t num_tokens = input.numel() / d;                                                 \
-  dim3 grid(num_tokens);                                                                  \
-  dim3 block(std::min(d, 1024));                                                          \
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
-    input.scalar_type(),                                                                  \
-    "activation_kernel",                                                                  \
-    [&] {                                                                                 \
-      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(    \
-        out.data_ptr<scalar_t>(),                                                         \
-        input.data_ptr<scalar_t>(),                                                       \
-        d);                                                                               \
-    });
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
+  int d = input.size(-1);                                                      \
+  int64_t num_tokens = input.numel() / d;                                      \
+  dim3 grid(num_tokens);                                                       \
+  dim3 block(std::min(d, 1024));                                               \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
+  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
+    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
+                                     input.data_ptr<scalar_t>(), d);           \
+  });
 
 namespace vllm {
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_new_kernel(const T& x) {
-  const float x3 = (float) (x * x * x);
-  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
-  return ((T) 0.5) * x * (((T) 1.0) + t);
+  const float x3 = (float)(x * x * x);
+  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
+  return ((T)0.5) * x * (((T)1.0) + t);
 }
 
-template<typename T>
+template <typename T>
 __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
-  const float f = (float) x;
-  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
-  return ((T) 0.5) * x * (((T) 1.0) + t);
+  const float f = (float)x;
+  const T t =
+      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
+  return ((T)0.5) * x * (((T)1.0) + t);
 }
 
-} // namespace vllm
+}  // namespace vllm
 
-void gelu_new(
-  torch::Tensor& out,     // [..., d]
-  torch::Tensor& input)   // [..., d]
+void gelu_new(torch::Tensor& out,    // [..., d]
+              torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
 }
 
-void gelu_fast(
-  torch::Tensor& out,     // [..., d]
-  torch::Tensor& input)   // [..., d]
+void gelu_fast(torch::Tensor& out,    // [..., d]
+               torch::Tensor& input)  // [..., d]
 {
   LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
 }
 
@@ -4,4 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
-#include "dtype_fp8_e5m2.cuh"
+#include "dtype_fp8.cuh"
 
@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
  *
@@ -22,31 +23,31 @@
 namespace vllm {
 
 // A vector type to store Q, K, V elements.
-template<typename T, int VEC_SIZE>
+template <typename T, int VEC_SIZE>
 struct Vec {};
 
 // A vector type to store FP32 accumulators.
-template<typename T>
+template <typename T>
 struct FloatVec {};
 
 // Template vector operations.
-template<typename Acc, typename A, typename B>
+template <typename Acc, typename A, typename B>
 inline __device__ Acc mul(A a, B b);
 
-template<typename T>
+template <typename T>
 inline __device__ float sum(T v);
 
-template<typename T>
+template <typename T>
 inline __device__ float dot(T a, T b) {
   return sum(mul<T, T, T>(a, b));
 }
 
-template<typename A, typename T>
+template <typename A, typename T>
 inline __device__ float dot(T a, T b) {
   return sum(mul<A, T, T>(a, b));
 }
 
-template<typename T>
+template <typename T>
 inline __device__ void zero(T& dst) {
   constexpr int WORDS = sizeof(T) / 4;
   union {
@@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
   dst = tmp.raw;
 }
 
-} // namespace vllm
+}  // namespace vllm
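The act_and_mul kernels above split the last dimension in half, apply the activation to the first half, and multiply by the second half. A PyTorch reference of that contract, useful for checking a custom kernel against; this is a reference sketch, not the CUDA implementation itself.

import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def gelu_tanh_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]


x = torch.randn(4, 2 * 128)
assert silu_and_mul_ref(x).shape == (4, 128)
assert gelu_tanh_and_mul_ref(x).shape == (4, 128)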
										
											
(File diff suppressed because it is too large.)
@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
  *
@@ -26,7 +27,7 @@
 namespace vllm {
 
 // Q*K^T operation.
-template<int THREAD_GROUP_SIZE, typename Vec, int N>
+template <int THREAD_GROUP_SIZE, typename Vec, int N>
 inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
   using A_vec = typename FloatVec<Vec>::Type;
   // Compute the parallel products for Q*K^T (treat vector lanes separately).
@@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
   return qk;
 }
 
-template<typename T, int THREAD_GROUP_SIZE>
+template <typename T, int THREAD_GROUP_SIZE>
 struct Qk_dot {
-  template<typename Vec, int N>
+  template <typename Vec, int N>
   static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
     return qk_dot_<THREAD_GROUP_SIZE>(q, k);
   }
 };
 
-} // namespace vllm
+}  // namespace vllm
 
@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
@@ -28,8 +30,8 @@
   #include <hip/hip_bf16.h>
   #include <hip/hip_fp16.h>
 
-  typedef __hip_bfloat162 __nv_bfloat162;
-  typedef __hip_bfloat16 __nv_bfloat16;
+typedef __hip_bfloat162 __nv_bfloat162;
+typedef __hip_bfloat16 __nv_bfloat16;
 #endif
 
 #include <stdint.h>
@@ -50,37 +52,37 @@ struct bf16_8_t {
 };
 
 // BF16 vector types for Q, K, V.
-template<>
+template <>
 struct Vec<__nv_bfloat16, 1> {
   using Type = __nv_bfloat16;
 };
-template<>
+template <>
 struct Vec<__nv_bfloat16, 2> {
   using Type = __nv_bfloat162;
 };
-template<>
+template <>
 struct Vec<__nv_bfloat16, 4> {
   using Type = bf16_4_t;
 };
-template<>
+template <>
 struct Vec<__nv_bfloat16, 8> {
   using Type = bf16_8_t;
 };
 
 // FP32 accumulator vector types corresponding to Vec.
-template<>
+template <>
 struct FloatVec<__nv_bfloat16> {
   using Type = float;
 };
-template<>
+template <>
 struct FloatVec<__nv_bfloat162> {
   using Type = float2;
 };
-template<>
+template <>
 struct FloatVec<bf16_4_t> {
   using Type = Float4_;
 };
-template<>
+template <>
 struct FloatVec<bf16_8_t> {
   using Type = Float8_;
 };
@@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
   assert(false);
 #else
   #ifndef USE_ROCM
-    return a + b;
+  return a + b;
   #else
-    return __hadd(a, b);
+  return __hadd(a, b);
   #endif
 #endif
 }
@@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
 }
 
 // Vector multiplication.
-template<>
+template <>
 inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
@@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
 #endif
 }
 
-template<>
+template <>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
@@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
 #endif
 }
 
-template<>
+template <>
 inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
   return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
 }
 
-template<>
+template <>
 inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
   bf16_4_t c;
   c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
   __nv_bfloat162 s = bf162bf162(a);
   bf16_4_t c;
@@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
   bf16_8_t c;
   c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
   __nv_bfloat162 s = bf162bf162(a);
   bf16_8_t c;
@@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
   float fa = __bfloat162float(a);
   float fb = __bfloat162float(b);
   return fa * fb;
 }
 
-template<>
+template <>
 inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
   float2 fa = bf1622float2(a);
   float2 fb = bf1622float2(b);
   return mul<float2, float2, float2>(fa, fb);
 }
 
-template<>
+template <>
 inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
   return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
 }
 
-template<>
+template <>
 inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
   Float4_ fc;
   fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
   return fc;
 }
 
-template<>
+template <>
 inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
   __nv_bfloat162 s = bf162bf162(a);
   Float4_ fc;
@@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
   return fc;
 }
 
-template<>
+template <>
 inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
   Float8_ fc;
   fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
   return fc;
 }
 
-template<>
+template <>
 inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
   __nv_bfloat162 s = bf162bf162(a);
   Float8_ fc;
@@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
 }
 
 // Vector fused multiply-add.
-inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
 #else
@@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
 #endif
 }
 
-inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
+inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
+                                     __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
 #else
@@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
 }
 
 // Vector sum.
-template<>
+template <>
 inline __device__ float sum(__nv_bfloat16 v) {
   return __bfloat162float(v);
 }
 
-template<>
+template <>
 inline __device__ float sum(__nv_bfloat162 v) {
   float2 vf = bf1622float2(v);
   return vf.x + vf.y;
 }
 
-template<>
+template <>
 inline __device__ float sum(bf16_4_t v) {
   return sum(v.x) + sum(v.y);
 }
 
-template<>
+template <>
 inline __device__ float sum(bf16_8_t v) {
   return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
 }
@@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
 #endif
 }
 
-} // namespace vllm
+}  // namespace vllm
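The bf16 helpers above map vector types to fp32 accumulators (FloatVec), so products such as the Q*K^T dot are accumulated in float. A short PyTorch illustration of the same idea; it is not the kernel code, just a precision check.

import torch

q = torch.randn(128, dtype=torch.bfloat16)
k = torch.randn(128, dtype=torch.bfloat16)

# Accumulate the dot product in fp32, as the FloatVec/qk_dot_ helpers do.
qk_fp32 = (q.float() * k.float()).sum()

# The same dot computed on bf16 tensors rounds the result to bf16 precision.
qk_bf16 = (q * k).sum()

print(f"fp32-accumulated: {qk_fp32.item():.6f}")
print(f"bf16 result     : {float(qk_bf16):.6f}")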
@@ -1,6 +1,8 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
- * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
  * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
@@ -30,37 +32,37 @@
 namespace vllm {
 
 // FP16 vector types for Q, K, V.
-template<>
+template <>
 struct Vec<uint16_t, 1> {
   using Type = uint16_t;
 };
-template<>
+template <>
 struct Vec<uint16_t, 2> {
   using Type = uint32_t;
 };
-template<>
+template <>
 struct Vec<uint16_t, 4> {
   using Type = uint2;
 };
-template<>
+template <>
 struct Vec<uint16_t, 8> {
   using Type = uint4;
 };
 
 // FP32 accumulator vector types corresponding to Vec.
-template<>
+template <>
 struct FloatVec<uint16_t> {
   using Type = float;
 };
-template<>
+template <>
 struct FloatVec<uint32_t> {
   using Type = float2;
 };
-template<>
+template <>
 struct FloatVec<uint2> {
   using Type = Float4_;
 };
-template<>
+template <>
 struct FloatVec<uint4> {
   using Type = Float8_;
 };
@@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
   return b;
 #else
   union {
-   uint32_t u32;
-   uint16_t u16[2];
+    uint32_t u32;
+    uint16_t u16[2];
   } tmp;
   tmp.u16[0] = a;
   tmp.u16[1] = a;
@@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
   } tmp;
 #ifndef USE_ROCM
   #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
+               : "=r"(tmp.u32)
+               : "f"(f.y), "f"(f.x));
   #else
-    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
-    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
   #endif
 #else
   tmp.u16[0] = float_to_half(f.x);
@@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
 }
 
 // Vector multiplication.
-template<>
+template <>
 inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
   uint16_t c;
 #ifndef USE_ROCM
@@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
   uint32_t c;
 #ifndef USE_ROCM
@@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
   return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
 }
 
-template<>
+template <>
 inline __device__ uint2 mul(uint2 a, uint2 b) {
   uint2 c;
   c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ uint2 mul(uint16_t a, uint2 b) {
   uint32_t s = h0_h0(a);
   uint2 c;
@@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ uint4 mul(uint4 a, uint4 b) {
   uint4 c;
   c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ uint4 mul(uint16_t a, uint4 b) {
   uint32_t s = h0_h0(a);
   uint4 c;
@@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
   return c;
 }
 
-template<>
+template <>
 inline __device__ float mul(uint16_t a, uint16_t b) {
   float fa = half_to_float(a);
   float fb = half_to_float(b);
   return fa * fb;
 }
 
-template<>
+template <>
 inline __device__ float2 mul(uint32_t a, uint32_t b) {
   float2 fa = half2_to_float2(a);
   float2 fb = half2_to_float2(b);
   return mul<float2, float2, float2>(fa, fb);
 }
 
-template<>
+template <>
 inline __device__ float2 mul(uint16_t a, uint32_t b) {
   return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
 }
 
-template<>
+template <>
 inline __device__ Float4_ mul(uint2 a, uint2 b) {
   Float4_ fc;
   fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
   return fc;
 }
 
-template<>
+template <>
 | 
			
		||||
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
 | 
			
		||||
  uint32_t s = h0_h0(a);
 | 
			
		||||
  Float4_ fc;
 | 
			
		||||
@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
 | 
			
		||||
  return fc;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ Float8_ mul(uint4 a, uint4 b) {
 | 
			
		||||
  Float8_ fc;
 | 
			
		||||
  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
 | 
			
		||||
@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
 | 
			
		||||
  return fc;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
 | 
			
		||||
  uint32_t s = h0_h0(a);
 | 
			
		||||
  Float8_ fc;
 | 
			
		||||
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
 | 
			
		||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
 | 
			
		||||
  uint32_t d;
 | 
			
		||||
#ifndef USE_ROCM
 | 
			
		||||
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
 | 
			
		||||
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
 | 
			
		||||
               : "=r"(d)
 | 
			
		||||
               : "r"(a), "r"(b), "r"(c));
 | 
			
		||||
#else
 | 
			
		||||
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
 | 
			
		||||
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
 | 
			
		||||
               : "=v"(d)
 | 
			
		||||
               : "v"(a), "v"(b), "v"(c));
 | 
			
		||||
#endif
 | 
			
		||||
  return d;
 | 
			
		||||
}
 | 
			
		||||
@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Vector sum.
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(uint16_t v) {
 | 
			
		||||
  return half_to_float(v);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(uint32_t v) {
 | 
			
		||||
  float2 tmp = half2_to_float2(v);
 | 
			
		||||
  return tmp.x + tmp.y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(uint2 v) {
 | 
			
		||||
  uint32_t c = add(v.x, v.y);
 | 
			
		||||
  return sum(c);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(uint4 v) {
 | 
			
		||||
  uint32_t c = add(v.x, v.y);
 | 
			
		||||
  c = add(c, v.z);
 | 
			
		||||
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// From float16 to float32.
 | 
			
		||||
inline __device__ float to_float(uint16_t u) {
 | 
			
		||||
  return half_to_float(u);
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
 | 
			
		||||
 | 
			
		||||
inline __device__ float2 to_float(uint32_t u) {
 | 
			
		||||
  return half2_to_float2(u);
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
 | 
			
		||||
 | 
			
		||||
inline __device__ Float4_ to_float(uint2 u) {
 | 
			
		||||
  Float4_ tmp;
 | 
			
		||||
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Zero-out a variable.
 | 
			
		||||
inline __device__ void zero(uint16_t& dst) {
 | 
			
		||||
  dst = uint16_t(0);
 | 
			
		||||
}
 | 
			
		||||
inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
 | 
			
		||||
 | 
			
		||||
} // namespace vllm
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,8 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Adapted from
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and
 * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
@ -38,37 +40,35 @@ struct Float8_ {
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// FP32 vector types for Q, K, V.
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
struct Vec<float, 1> {
 | 
			
		||||
  using Type = float;
 | 
			
		||||
};
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
struct Vec<float, 2> {
 | 
			
		||||
  using Type = float2;
 | 
			
		||||
};
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
struct Vec<float, 4> {
 | 
			
		||||
  using Type = float4;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// FP32 accumulator vector types corresponding to Vec.
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
struct FloatVec<float> {
 | 
			
		||||
  using Type = float;
 | 
			
		||||
};
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
struct FloatVec<float2> {
 | 
			
		||||
  using Type = float2;
 | 
			
		||||
};
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
struct FloatVec<float4> {
 | 
			
		||||
  using Type = float4;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Vector addition.
 | 
			
		||||
inline __device__ float add(float a, float b) {
 | 
			
		||||
  return a + b;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float add(float a, float b) { return a + b; }
 | 
			
		||||
 | 
			
		||||
inline __device__ float2 add(float2 a, float2 b) {
 | 
			
		||||
  float2 c;
 | 
			
		||||
@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Vector multiplication.
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float mul<float, float>(float a, float b) {
 | 
			
		||||
  return a * b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float2 mul(float2 a, float2 b) {
 | 
			
		||||
  float2 c;
 | 
			
		||||
  c.x = a.x * b.x;
 | 
			
		||||
@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
 | 
			
		||||
  return c;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float2 mul(float a, float2 b) {
 | 
			
		||||
  float2 c;
 | 
			
		||||
  c.x = a * b.x;
 | 
			
		||||
@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
 | 
			
		||||
  return c;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float4 mul(float4 a, float4 b) {
 | 
			
		||||
  float4 c;
 | 
			
		||||
  c.x = a.x * b.x;
 | 
			
		||||
@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
 | 
			
		||||
  return c;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float4 mul(float a, float4 b) {
 | 
			
		||||
  float4 c;
 | 
			
		||||
  c.x = a * b.x;
 | 
			
		||||
@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Vector fused multiply-add.
 | 
			
		||||
inline __device__ float fma(float a, float b, float c) {
 | 
			
		||||
  return a * b + c;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float fma(float a, float b, float c) { return a * b + c; }
 | 
			
		||||
 | 
			
		||||
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
 | 
			
		||||
  float2 d;
 | 
			
		||||
@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Vector sum.
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(float v) {
 | 
			
		||||
  return v;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(float2 v) {
 | 
			
		||||
  return v.x + v.y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(float4 v) {
 | 
			
		||||
  return v.x + v.y + v.z + v.w;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(Float4_ v) {
 | 
			
		||||
  return v.x.x + v.x.y + v.y.x + v.y.y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
template <>
 | 
			
		||||
inline __device__ float sum(Float8_ v) {
 | 
			
		||||
  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Vector dot product.
 | 
			
		||||
inline __device__ float dot(float a, float b) {
 | 
			
		||||
  return a * b;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float dot(float a, float b) { return a * b; }
 | 
			
		||||
 | 
			
		||||
inline __device__ float dot(float2 a, float2 b) {
 | 
			
		||||
  float2 c = mul<float2, float2, float2>(a, b);
 | 
			
		||||
@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// From float to float.
 | 
			
		||||
inline __device__ void from_float(float& dst, float src) {
 | 
			
		||||
  dst = src;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ void from_float(float& dst, float src) { dst = src; }
 | 
			
		||||
 | 
			
		||||
inline __device__ void from_float(float2& dst, float2 src) {
 | 
			
		||||
  dst = src;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
 | 
			
		||||
 | 
			
		||||
inline __device__ void from_float(float4& dst, float4 src) {
 | 
			
		||||
  dst = src;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
 | 
			
		||||
 | 
			
		||||
// From float to float.
 | 
			
		||||
inline __device__ float to_float(float u) {
 | 
			
		||||
  return u;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float to_float(float u) { return u; }
 | 
			
		||||
 | 
			
		||||
inline __device__ float2 to_float(float2 u) {
 | 
			
		||||
  return u;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float2 to_float(float2 u) { return u; }
 | 
			
		||||
 | 
			
		||||
inline __device__ float4 to_float(float4 u) {
 | 
			
		||||
  return u;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ float4 to_float(float4 u) { return u; }
 | 
			
		||||
 | 
			
		||||
inline __device__ Float4_ to_float(Float4_ u) {
 | 
			
		||||
  return u;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ Float4_ to_float(Float4_ u) { return u; }
 | 
			
		||||
 | 
			
		||||
inline __device__ Float8_ to_float(Float8_ u) {
 | 
			
		||||
  return u;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ Float8_ to_float(Float8_ u) { return u; }
 | 
			
		||||
 | 
			
		||||
// Zero-out a variable.
 | 
			
		||||
inline __device__ void zero(float& dst) {
 | 
			
		||||
  dst = 0.f;
 | 
			
		||||
}
 | 
			
		||||
inline __device__ void zero(float& dst) { dst = 0.f; }
 | 
			
		||||
 | 
			
		||||
} // namespace vllm
 | 
			
		||||
}  // namespace vllm

csrc/attention/dtype_fp8.cuh (new file, 41 lines)
@ -0,0 +1,41 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8
  #ifndef USE_ROCM
    #include <cuda_fp8.h>
  #endif  // USE_ROCM
#endif    // ENABLE_FP8

namespace vllm {

enum class Fp8KVCacheDataType {
  kAuto = 0,
  kFp8E4M3 = 1,
  kFp8E5M2 = 2,
};

// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
  using Type = uint8_t;
};

template <>
struct Vec<uint8_t, 2> {
  using Type = uint16_t;
};

template <>
struct Vec<uint8_t, 4> {
  using Type = uint32_t;
};

template <>
struct Vec<uint8_t, 8> {
  using Type = uint2;
};

}  // namespace vllm
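The four specializations above only map a count of packed fp8 bytes onto an unsigned storage type (1 -> uint8_t, 2 -> uint16_t, 4 -> uint32_t, 8 -> uint2). As a hedged illustration, not part of the diff, the sketch below shows what a caller gets from that mapping; the alias PackedFp8 is hypothetical, and it assumes the primary vllm::Vec template from attention_generic.cuh and CUDA's built-in uint2 are in scope.

// Hypothetical sketch: resolve the packed storage type for N fp8 bytes.
#include <cstdint>
#include "attention_generic.cuh"
#include "dtype_fp8.cuh"

template <int N>
using PackedFp8 = typename vllm::Vec<uint8_t, N>::Type;

static_assert(sizeof(PackedFp8<1>) == 1, "1 fp8 value  -> uint8_t");
static_assert(sizeof(PackedFp8<2>) == 2, "2 fp8 values -> uint16_t");
static_assert(sizeof(PackedFp8<4>) == 4, "4 fp8 values -> uint32_t");
static_assert(sizeof(PackedFp8<8>) == 8, "8 fp8 values -> uint2 (two 32-bit words)");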
@ -1,35 +0,0 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif

namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache

template<>
struct Vec<uint8_t, 1> {
    using Type = uint8_t;
};

template<>
struct Vec<uint8_t, 2> {
    using Type = uint16_t;
};

template<>
struct Vec<uint8_t, 4> {
    using Type = uint32_t;
};

template<>
struct Vec<uint8_t, 8> {
    using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm

csrc/cache.h (41 lines changed)
@ -1,29 +1,32 @@
#pragma once

#include <torch/extension.h>
#include <torch/all.h>

#include <map>
#include <vector>

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping);
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);

void copy_blocks(
  std::vector<torch::Tensor>& key_caches,
  std::vector<torch::Tensor>& value_caches,
  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
                 std::vector<torch::Tensor> const& value_caches,
                 const torch::Tensor& block_mapping);

void reshape_and_cache(
  torch::Tensor& key,
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping,
  const std::string& kv_cache_dtype);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
                       torch::Tensor& slot_mapping,
                       const std::string& kv_cache_dtype,
                       const double kv_scale);

void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             torch::Tensor& key_cache,
                             torch::Tensor& value_cache,
                             torch::Tensor& slot_mapping,
                             const std::string& kv_cache_dtype);

// Just for unittest
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache);
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);
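The new declarations above replace the std::map-based block mappings with a single block_mapping tensor. As a hedged sketch, not part of the diff, the helper below builds such a tensor with plain libtorch calls; the name make_block_mapping is hypothetical, while the int64 dtype, [num_pairs, 2] shape, and CPU placement follow the kernel code later in this diff.

// Hypothetical helper: pack (src, dst) block-number pairs into the
// [num_pairs, 2] int64 CPU tensor that swap_blocks/copy_blocks now take.
#include <torch/all.h>
#include <utility>
#include <vector>

torch::Tensor make_block_mapping(
    const std::vector<std::pair<int64_t, int64_t>>& pairs) {
  auto opts = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCPU);
  torch::Tensor mapping =
      torch::empty({static_cast<int64_t>(pairs.size()), 2}, opts);
  auto acc = mapping.accessor<int64_t, 2>();
  for (size_t i = 0; i < pairs.size(); ++i) {
    acc[i][0] = pairs[i].first;   // source block number
    acc[i][1] = pairs[i].second;  // destination block number
  }
  // Kept on the CPU so per-element reads in swap_blocks stay cheap.
  return mapping;
}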
@ -1,11 +1,14 @@
#include <torch/extension.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"

#ifdef USE_ROCM
  #include "quantization/fp8/amd/quant_utils.cuh"
#else
  #include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

#include <algorithm>
@ -15,20 +18,17 @@
 | 
			
		||||
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  #include <hip/hip_bf16.h>
 | 
			
		||||
  typedef __hip_bfloat16 __nv_bfloat16;
 | 
			
		||||
typedef __hip_bfloat16 __nv_bfloat16;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
void swap_blocks(
 | 
			
		||||
  torch::Tensor& src,
 | 
			
		||||
  torch::Tensor& dst,
 | 
			
		||||
  const std::map<int64_t, int64_t>& block_mapping) {
 | 
			
		||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
 | 
			
		||||
                 const torch::Tensor& block_mapping) {
 | 
			
		||||
  torch::Device src_device = src.device();
 | 
			
		||||
  torch::Device dst_device = dst.device();
 | 
			
		||||
  cudaMemcpyKind memcpy_type;
 | 
			
		||||
  if (src_device.is_cuda() && dst_device.is_cuda()) {
 | 
			
		||||
    TORCH_CHECK(
 | 
			
		||||
      src_device.index() == dst_device.index(),
 | 
			
		||||
      "src and dst must be on the same GPU");
 | 
			
		||||
    TORCH_CHECK(src_device.index() == dst_device.index(),
 | 
			
		||||
                "src and dst must be on the same GPU");
 | 
			
		||||
    memcpy_type = cudaMemcpyDeviceToDevice;
 | 
			
		||||
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
 | 
			
		||||
    memcpy_type = cudaMemcpyDeviceToHost;
 | 
			
		||||
@ -38,41 +38,44 @@ void swap_blocks(
 | 
			
		||||
    TORCH_CHECK(false, "Invalid device combination");
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  char *src_ptr = static_cast<char*>(src.data_ptr());
 | 
			
		||||
  char *dst_ptr = static_cast<char*>(dst.data_ptr());
 | 
			
		||||
  // NOTE(youkaichao): keep in mind that `block_mapping` should be
 | 
			
		||||
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
 | 
			
		||||
  // synchronization.
 | 
			
		||||
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
 | 
			
		||||
 | 
			
		||||
  char* src_ptr = static_cast<char*>(src.data_ptr());
 | 
			
		||||
  char* dst_ptr = static_cast<char*>(dst.data_ptr());
 | 
			
		||||
 | 
			
		||||
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(
 | 
			
		||||
      src_device.is_cuda() ? src_device : dst_device);
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  // NOTE(woosuk): This can be slow if the number of blocks is large.
 | 
			
		||||
  for (const auto& pair : block_mapping) {
 | 
			
		||||
    int64_t src_block_number = pair.first;
 | 
			
		||||
    int64_t dst_block_number = pair.second;
 | 
			
		||||
  const int64_t num_blocks = block_mapping.size(0);
 | 
			
		||||
  for (size_t i = 0; i < num_blocks; i++) {
 | 
			
		||||
    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
 | 
			
		||||
    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
 | 
			
		||||
    int64_t src_offset = src_block_number * block_size_in_bytes;
 | 
			
		||||
    int64_t dst_offset = dst_block_number * block_size_in_bytes;
 | 
			
		||||
    cudaMemcpyAsync(
 | 
			
		||||
      dst_ptr + dst_offset,
 | 
			
		||||
      src_ptr + src_offset,
 | 
			
		||||
      block_size_in_bytes,
 | 
			
		||||
      memcpy_type,
 | 
			
		||||
      stream);
 | 
			
		||||
    cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
 | 
			
		||||
                    block_size_in_bytes, memcpy_type, stream);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
 | 
			
		||||
// Grid: (num_layers, num_pairs)
 | 
			
		||||
template<typename scalar_t>
 | 
			
		||||
__global__ void copy_blocks_kernel(
 | 
			
		||||
  int64_t* key_cache_ptrs,
 | 
			
		||||
  int64_t* value_cache_ptrs,
 | 
			
		||||
  const int64_t* __restrict__ block_mapping,
 | 
			
		||||
  const int numel_per_block) {
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
 | 
			
		||||
                                   int64_t* value_cache_ptrs,
 | 
			
		||||
                                   const int64_t* __restrict__ block_mapping,
 | 
			
		||||
                                   const int numel_per_block) {
 | 
			
		||||
  const int layer_idx = blockIdx.x;
 | 
			
		||||
  const int pair_idx = blockIdx.y;
 | 
			
		||||
 | 
			
		||||
  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
 | 
			
		||||
  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
 | 
			
		||||
  scalar_t* value_cache =
 | 
			
		||||
      reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
 | 
			
		||||
  int64_t src_block_number = block_mapping[2 * pair_idx];
 | 
			
		||||
  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
 | 
			
		||||
 | 
			
		||||
@ -90,12 +93,14 @@ __global__ void copy_blocks_kernel(
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace vllm
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
 | 
			
		||||
void copy_blocks(
 | 
			
		||||
  std::vector<torch::Tensor>& key_caches,
 | 
			
		||||
  std::vector<torch::Tensor>& value_caches,
 | 
			
		||||
  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
 | 
			
		||||
// Note: the key_caches and value_caches vectors are constant but
 | 
			
		||||
// not the Tensors they contain. The vectors need to be const refs
 | 
			
		||||
// in order to satisfy pytorch's C++ operator registration code.
 | 
			
		||||
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 | 
			
		||||
                 std::vector<torch::Tensor> const& value_caches,
 | 
			
		||||
                 const torch::Tensor& block_mapping) {
 | 
			
		||||
  int num_layers = key_caches.size();
 | 
			
		||||
  TORCH_CHECK(num_layers == value_caches.size());
 | 
			
		||||
  if (num_layers == 0) {
 | 
			
		||||
@ -109,29 +114,23 @@ void copy_blocks(
 | 
			
		||||
  int64_t key_cache_ptrs[num_layers];
 | 
			
		||||
  int64_t value_cache_ptrs[num_layers];
 | 
			
		||||
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
 | 
			
		||||
    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
 | 
			
		||||
    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
 | 
			
		||||
    key_cache_ptrs[layer_idx] =
 | 
			
		||||
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
 | 
			
		||||
    value_cache_ptrs[layer_idx] =
 | 
			
		||||
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
 | 
			
		||||
  }
 | 
			
		||||
  // Create block mapping array.
 | 
			
		||||
  std::vector<int64_t> block_mapping_vec;
 | 
			
		||||
  for (const auto& pair : block_mapping) {
 | 
			
		||||
    int64_t src_block_number = pair.first;
 | 
			
		||||
    for (int64_t dst_block_number : pair.second) {
 | 
			
		||||
      block_mapping_vec.push_back(src_block_number);
 | 
			
		||||
      block_mapping_vec.push_back(dst_block_number);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  int64_t* block_mapping_array = block_mapping_vec.data();
 | 
			
		||||
  int num_pairs = block_mapping_vec.size() / 2;
 | 
			
		||||
 | 
			
		||||
  // block_mapping is a 2D tensor with shape (num_pairs, 2).
 | 
			
		||||
  int num_pairs = block_mapping.size(0);
 | 
			
		||||
 | 
			
		||||
  // Move the data structures to the GPU.
 | 
			
		||||
  // NOTE: This synchronizes the CPU and GPU.
 | 
			
		||||
  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
 | 
			
		||||
    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
 | 
			
		||||
  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
 | 
			
		||||
    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
 | 
			
		||||
  torch::Tensor block_mapping_tensor = torch::from_blob(
 | 
			
		||||
    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
 | 
			
		||||
  torch::Tensor key_cache_ptrs_tensor =
 | 
			
		||||
      torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
 | 
			
		||||
          .to(cache_device);
 | 
			
		||||
  torch::Tensor value_cache_ptrs_tensor =
 | 
			
		||||
      torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
 | 
			
		||||
          .to(cache_device);
 | 
			
		||||
 | 
			
		||||
  // Launch the kernel.
 | 
			
		||||
  const int numel_per_block = key_caches[0][0].numel();
 | 
			
		||||
@ -140,30 +139,28 @@ void copy_blocks(
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
 | 
			
		||||
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
 | 
			
		||||
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
 | 
			
		||||
        key_cache_ptrs_tensor.data_ptr<int64_t>(),
 | 
			
		||||
        value_cache_ptrs_tensor.data_ptr<int64_t>(),
 | 
			
		||||
        block_mapping_tensor.data_ptr<int64_t>(),
 | 
			
		||||
        numel_per_block);
 | 
			
		||||
    }));
 | 
			
		||||
      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
 | 
			
		||||
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
 | 
			
		||||
            key_cache_ptrs_tensor.data_ptr<int64_t>(),
 | 
			
		||||
            value_cache_ptrs_tensor.data_ptr<int64_t>(),
 | 
			
		||||
            block_mapping.data_ptr<int64_t>(), numel_per_block);
 | 
			
		||||
      }));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
 | 
			
		||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
 | 
			
		||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 | 
			
		||||
__global__ void reshape_and_cache_kernel(
 | 
			
		||||
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
 | 
			
		||||
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
 | 
			
		||||
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
 | 
			
		||||
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
 | 
			
		||||
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
 | 
			
		||||
  const int key_stride,
 | 
			
		||||
  const int value_stride,
 | 
			
		||||
  const int num_heads,
 | 
			
		||||
  const int head_size,
 | 
			
		||||
  const int block_size,
 | 
			
		||||
  const int x) {
 | 
			
		||||
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
 | 
			
		||||
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
 | 
			
		||||
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x,
 | 
			
		||||
                                         // block_size, x]
 | 
			
		||||
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
 | 
			
		||||
                                         // block_size]
 | 
			
		||||
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
 | 
			
		||||
    const int key_stride, const int value_stride, const int num_heads,
 | 
			
		||||
    const int head_size, const int block_size, const int x,
 | 
			
		||||
    const float kv_scale) {
 | 
			
		||||
  const int64_t token_idx = blockIdx.x;
 | 
			
		||||
  const int64_t slot_idx = slot_mapping[token_idx];
 | 
			
		||||
  if (slot_idx < 0) {
 | 
			
		||||
@ -184,55 +181,84 @@ __global__ void reshape_and_cache_kernel(
 | 
			
		||||
    const int x_idx = head_offset / x;
 | 
			
		||||
    const int x_offset = head_offset % x;
 | 
			
		||||
 | 
			
		||||
    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
 | 
			
		||||
                                + head_idx * (head_size / x) * block_size * x
 | 
			
		||||
                                + x_idx * block_size * x
 | 
			
		||||
                                + block_offset * x
 | 
			
		||||
                                + x_offset;
 | 
			
		||||
    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
 | 
			
		||||
                                  + head_idx * head_size * block_size
 | 
			
		||||
                                  + head_offset * block_size
 | 
			
		||||
                                  + block_offset;
 | 
			
		||||
    const int64_t tgt_key_idx =
 | 
			
		||||
        block_idx * num_heads * (head_size / x) * block_size * x +
 | 
			
		||||
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
 | 
			
		||||
        block_offset * x + x_offset;
 | 
			
		||||
    const int64_t tgt_value_idx =
 | 
			
		||||
        block_idx * num_heads * head_size * block_size +
 | 
			
		||||
        head_idx * head_size * block_size + head_offset * block_size +
 | 
			
		||||
        block_offset;
 | 
			
		||||
    scalar_t tgt_key = key[src_key_idx];
 | 
			
		||||
    scalar_t tgt_value = value[src_value_idx];
 | 
			
		||||
    if constexpr (is_fp8_e5m2_kv_cache) {
 | 
			
		||||
#ifdef ENABLE_FP8_E5M2
 | 
			
		||||
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
 | 
			
		||||
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
 | 
			
		||||
#else
 | 
			
		||||
      assert(false);
 | 
			
		||||
#endif
 | 
			
		||||
    } else {
 | 
			
		||||
    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
 | 
			
		||||
      key_cache[tgt_key_idx] = tgt_key;
 | 
			
		||||
      value_cache[tgt_value_idx] = tgt_value;
 | 
			
		||||
    } else {
 | 
			
		||||
      key_cache[tgt_key_idx] =
 | 
			
		||||
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
 | 
			
		||||
      value_cache[tgt_value_idx] =
 | 
			
		||||
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace vllm
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
__global__ void reshape_and_cache_flash_kernel(
 | 
			
		||||
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
 | 
			
		||||
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
 | 
			
		||||
    scalar_t* __restrict__ k_cache,      // [num_blocks, block_size, num_heads,
 | 
			
		||||
                                         // head_size]
 | 
			
		||||
    scalar_t* __restrict__ v_cache,      // [num_blocks, block_size, num_heads,
 | 
			
		||||
                                         // head_size]
 | 
			
		||||
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
 | 
			
		||||
    const int block_stride, const int key_stride, const int value_stride,
 | 
			
		||||
    const int num_heads, const int head_size, const int block_size) {
 | 
			
		||||
  const int64_t token_idx = blockIdx.x;
 | 
			
		||||
  const int64_t slot_idx = slot_mapping[token_idx];
 | 
			
		||||
  // NOTE: slot_idx can be -1 if the token is padded
 | 
			
		||||
  if (slot_idx < 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
  const int64_t block_idx = slot_idx / block_size;
 | 
			
		||||
  const int64_t block_offset = slot_idx % block_size;
 | 
			
		||||
  const int n = num_heads * head_size;
 | 
			
		||||
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
 | 
			
		||||
    const int64_t src_key_idx = token_idx * key_stride + i;
 | 
			
		||||
    const int64_t src_value_idx = token_idx * value_stride + i;
 | 
			
		||||
    const int head_idx = i / head_size;
 | 
			
		||||
    const int head_offset = i % head_size;
 | 
			
		||||
    const int64_t tgt_value_idx = block_idx * block_stride +
 | 
			
		||||
                                  block_offset * num_heads * head_size +
 | 
			
		||||
                                  head_idx * head_size + head_offset;
 | 
			
		||||
    k_cache[tgt_value_idx] = key[src_key_idx];
 | 
			
		||||
    v_cache[tgt_value_idx] = value[src_value_idx];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
 | 
			
		||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \
 | 
			
		||||
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
 | 
			
		||||
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
 | 
			
		||||
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
 | 
			
		||||
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
 | 
			
		||||
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
 | 
			
		||||
    slot_mapping.data_ptr<int64_t>(),                                                              \
 | 
			
		||||
    key_stride,                                                                                    \
 | 
			
		||||
    value_stride,                                                                                  \
 | 
			
		||||
    num_heads,                                                                                     \
 | 
			
		||||
    head_size,                                                                                     \
 | 
			
		||||
    block_size,                                                                                    \
 | 
			
		||||
    x);
 | 
			
		||||
// KV_T is the stored data type of kv-cache.
 | 
			
		||||
// CACHE_T is the data type of key and value tensors.
 | 
			
		||||
// KV_DTYPE is the real data type of kv-cache.
 | 
			
		||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
 | 
			
		||||
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
 | 
			
		||||
      <<<grid, block, 0, stream>>>(                                   \
 | 
			
		||||
          reinterpret_cast<KV_T*>(key.data_ptr()),                    \
 | 
			
		||||
          reinterpret_cast<KV_T*>(value.data_ptr()),                  \
 | 
			
		||||
          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),           \
 | 
			
		||||
          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),         \
 | 
			
		||||
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
 | 
			
		||||
          num_heads, head_size, block_size, x, kv_scale);
 | 
			
		||||
 | 
			
		||||
void reshape_and_cache(
 | 
			
		||||
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
 | 
			
		||||
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
 | 
			
		||||
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
 | 
			
		||||
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
 | 
			
		||||
  torch::Tensor& slot_mapping,  // [num_tokens]
 | 
			
		||||
  const std::string& kv_cache_dtype)
 | 
			
		||||
{
 | 
			
		||||
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
 | 
			
		||||
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
 | 
			
		||||
    torch::Tensor&
 | 
			
		||||
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
 | 
			
		||||
    torch::Tensor&
 | 
			
		||||
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
 | 
			
		||||
    torch::Tensor& slot_mapping,  // [num_tokens]
 | 
			
		||||
    const std::string& kv_cache_dtype, const double kv_scale) {
 | 
			
		||||
  int num_tokens = key.size(0);
 | 
			
		||||
  int num_heads = key.size(1);
 | 
			
		||||
  int head_size = key.size(2);
 | 
			
		||||
@ -246,57 +272,80 @@ void reshape_and_cache(
 | 
			
		||||
  dim3 block(std::min(num_heads * head_size, 512));
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  if (kv_cache_dtype == "auto") {
 | 
			
		||||
    if (key.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_RESHAPE_AND_CACHE(float, float, false);
 | 
			
		||||
    } else if (key.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
 | 
			
		||||
    } else if (key.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
 | 
			
		||||
    }
 | 
			
		||||
  } else if (kv_cache_dtype == "fp8_e5m2") {
 | 
			
		||||
    if (key.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
 | 
			
		||||
    } else if (key.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
 | 
			
		||||
    } else if (key.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
 | 
			
		||||
  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
 | 
			
		||||
                             CALL_RESHAPE_AND_CACHE)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void reshape_and_cache_flash(
 | 
			
		||||
    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
 | 
			
		||||
    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
 | 
			
		||||
    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
 | 
			
		||||
    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
 | 
			
		||||
    torch::Tensor& slot_mapping,  // [num_tokens]
 | 
			
		||||
    const std::string& kv_cache_dtype) {
 | 
			
		||||
  // FIXME: only support auto datatype, does not support fp8
 | 
			
		||||
  if (kv_cache_dtype != "auto") {
 | 
			
		||||
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
 | 
			
		||||
  }
 | 
			
		||||
  int num_tokens = key.size(0);
 | 
			
		||||
  int num_heads = key.size(1);
 | 
			
		||||
  int head_size = key.size(2);
 | 
			
		||||
  int block_size = k_cache.size(1);
 | 
			
		||||
 | 
			
		||||
  int key_stride = key.stride(0);
 | 
			
		||||
  int value_stride = value.stride(0);
 | 
			
		||||
  int block_stride = k_cache.stride(0);
 | 
			
		||||
  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
 | 
			
		||||
 | 
			
		||||
  dim3 grid(num_tokens);
 | 
			
		||||
  dim3 block(std::min(num_heads * head_size, 512));
 | 
			
		||||
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(
 | 
			
		||||
      key.scalar_type(), "reshape_and_cache_flash", [&] {
 | 
			
		||||
        vllm::reshape_and_cache_flash_kernel<scalar_t>
 | 
			
		||||
            <<<grid, block, 0, stream>>>(
 | 
			
		||||
                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
 | 
			
		||||
                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
 | 
			
		||||
                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
 | 
			
		||||
                value_stride, num_heads, head_size, block_size);
 | 
			
		||||
      });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
 | 
			
		||||
template<typename Tout, typename Tin>
 | 
			
		||||
__global__ void convert_fp8_e5m2_kernel(
 | 
			
		||||
  const Tin* __restrict__ src_cache,
 | 
			
		||||
  Tout* __restrict__ dst_cache,
 | 
			
		||||
  const int64_t block_stride) {
 | 
			
		||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
 | 
			
		||||
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
 | 
			
		||||
                                   Tout* __restrict__ dst_cache,
 | 
			
		||||
                                   const float kv_scale,
 | 
			
		||||
                                   const int64_t block_stride) {
 | 
			
		||||
  const int64_t block_idx = blockIdx.x;
 | 
			
		||||
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
 | 
			
		||||
    int64_t idx = block_idx * block_stride + i;
 | 
			
		||||
#ifdef ENABLE_FP8_E5M2
 | 
			
		||||
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
 | 
			
		||||
#else
 | 
			
		||||
    assert(false);
 | 
			
		||||
#endif
 | 
			
		||||
    dst_cache[idx] =
 | 
			
		||||
        fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace vllm
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
 | 
			
		||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
 | 
			
		||||
  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
 | 
			
		||||
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
 | 
			
		||||
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
 | 
			
		||||
    block_stride);
 | 
			
		||||
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
 | 
			
		||||
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
 | 
			
		||||
      reinterpret_cast<Tin*>(src_cache.data_ptr()),                          \
 | 
			
		||||
      reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
 | 
			
		||||
 | 
			
		||||
// Only for testing.
 | 
			
		||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 | 
			
		||||
                 const double kv_scale, const std::string& kv_cache_dtype) {
 | 
			
		||||
  torch::Device src_device = src_cache.device();
 | 
			
		||||
  torch::Device dst_device = dst_cache.device();
 | 
			
		||||
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
 | 
			
		||||
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
 | 
			
		||||
  TORCH_CHECK(src_device.index() == dst_device.index(),
 | 
			
		||||
              "src and dst must be on the same GPU");
 | 
			
		||||
  at::cuda::OptionalCUDAGuard device_guard(src_device);
 | 
			
		||||
 | 
			
		||||
void convert_fp8_e5m2(
 | 
			
		||||
  torch::Tensor& src_cache,
 | 
			
		||||
  torch::Tensor& dst_cache)
 | 
			
		||||
{
 | 
			
		||||
  int64_t num_blocks = src_cache.size(0);
 | 
			
		||||
  int64_t block_stride = src_cache.stride(0);
 | 
			
		||||
 | 
			
		||||
@ -304,17 +353,37 @@ void convert_fp8_e5m2(
 | 
			
		||||
  dim3 block(std::min(block_stride, int64_t(512)));
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
 | 
			
		||||
  if (src_cache.dtype() == at::ScalarType::Float) {
 | 
			
		||||
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
 | 
			
		||||
  } else if (src_cache.dtype() == at::ScalarType::Half) {
 | 
			
		||||
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
 | 
			
		||||
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
 | 
			
		||||
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
 | 
			
		||||
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
 | 
			
		||||
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
 | 
			
		||||
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
 | 
			
		||||
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
 | 
			
		||||
  if (kv_cache_dtype == "auto") {
 | 
			
		||||
    if (src_cache.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
 | 
			
		||||
    } else if (src_cache.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
 | 
			
		||||
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
 | 
			
		||||
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
 | 
			
		||||
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
 | 
			
		||||
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
 | 
			
		||||
    }
 | 
			
		||||
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
 | 
			
		||||
    if (src_cache.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
 | 
			
		||||
    } else if (src_cache.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
 | 
			
		||||
    } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
 | 
			
		||||
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
 | 
			
		||||
    } else if (dst_cache.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
 | 
			
		||||
    } else if (dst_cache.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
 | 
			
		||||
    } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
 | 
			
		||||
                       vllm::Fp8KVCacheDataType::kFp8E4M3);
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
 | 
			
		||||
  }
 | 
			
		||||
}

csrc/cpu/activation.cpp (new file, 144 lines)
@ -0,0 +1,144 @@
#include "cpu_types.hpp"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
 | 
			
		||||
          bool is_gated>
 | 
			
		||||
void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
 | 
			
		||||
                       scalar_t* __restrict__ output) {
 | 
			
		||||
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
 | 
			
		||||
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
 | 
			
		||||
 | 
			
		||||
  TORCH_CHECK(d % VEC_ELEM_NUM == 0);
 | 
			
		||||
 | 
			
		||||
#pragma omp parallel for
 | 
			
		||||
  for (int i = 0; i < num_tokens; ++i) {
 | 
			
		||||
    for (int j = 0; j < d; j += VEC_ELEM_NUM) {
 | 
			
		||||
      int start = i * d;
 | 
			
		||||
      if constexpr (is_gated) {
 | 
			
		||||
        start *= 2;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      const scalar_vec_t x(input + start + j);
 | 
			
		||||
      const vec_op::FP32Vec8 f32_x(x);
 | 
			
		||||
      vec_op::FP32Vec8 f32_ans = func(f32_x);
 | 
			
		||||
 | 
			
		||||
      if constexpr (is_gated) {
 | 
			
		||||
        const scalar_vec_t y(input + start + d + j);
 | 
			
		||||
        const vec_op::FP32Vec8 f32_y(y);
 | 
			
		||||
        f32_ans = f32_y * f32_ans;
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      const scalar_vec_t result(f32_ans);
 | 
			
		||||
      result.save(output + i * d + j);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
 | 
			
		||||
  const vec_op::FP32Vec8 zeros(0.0);
 | 
			
		||||
  const vec_op::FP32Vec8 ones(1.0);
 | 
			
		||||
  return x / (ones + (zeros - x).exp());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
 | 
			
		||||
  const vec_op::FP32Vec8 ones(1.0);
 | 
			
		||||
  const vec_op::FP32Vec8 w1(0.79788456f);
 | 
			
		||||
  const vec_op::FP32Vec8 w2(0.044715f);
 | 
			
		||||
  const vec_op::FP32Vec8 w3(0.5);
 | 
			
		||||
  const vec_op::FP32Vec8 x3 = x * x * x;
 | 
			
		||||
  const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
 | 
			
		||||
  return w3 * x * (ones + t);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
 | 
			
		||||
  const vec_op::FP32Vec8 ones(1.0);
 | 
			
		||||
  const vec_op::FP32Vec8 w1(0.79788456f);
 | 
			
		||||
  const vec_op::FP32Vec8 w2(0.044715f);
 | 
			
		||||
  const vec_op::FP32Vec8 w3(0.5);
 | 
			
		||||
  const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
 | 
			
		||||
  return w3 * x * (ones + t);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
 | 
			
		||||
  const vec_op::FP32Vec8 ones(1.0);
 | 
			
		||||
  const vec_op::FP32Vec8 w1(M_SQRT1_2);
 | 
			
		||||
  const vec_op::FP32Vec8 w2(0.5);
 | 
			
		||||
  return x * w2 * (ones + (x * w1).er());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
 | 
			
		||||
  const vec_op::FP32Vec8 ones(1.0);
 | 
			
		||||
  const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
 | 
			
		||||
  const vec_op::FP32Vec8 w2(0.5);
 | 
			
		||||
  const vec_op::FP32Vec8 w3(0.044715);
 | 
			
		||||
  const vec_op::FP32Vec8 x_3 = x * x * x;
 | 
			
		||||
  const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
 | 
			
		||||
  return x * w2 * (ones + inner.tanh());
 | 
			
		||||
}
 | 
			
		||||
};  // namespace
 | 
			
		||||
 | 
			
		||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
 | 
			
		||||
  int num_tokens = input.numel() / input.size(-1);
 | 
			
		||||
  int d = input.size(-1) / 2;
 | 
			
		||||
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
 | 
			
		||||
    CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
 | 
			
		||||
    activation_kernel<scalar_t, silu_act, true>(
 | 
			
		||||
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
 | 
			
		||||
    CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void gelu_and_mul(torch::Tensor& out,    // [..., d]
 | 
			
		||||
                  torch::Tensor& input)  // [..., 2 * d]
 | 
			
		||||
{
 | 
			
		||||
  int num_tokens = input.numel() / input.size(-1);
 | 
			
		||||
  int d = input.size(-1) / 2;
 | 
			
		||||
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
 | 
			
		||||
    CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
 | 
			
		||||
    activation_kernel<scalar_t, gelu_act, true>(
 | 
			
		||||
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
 | 
			
		||||
    CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
 | 
			
		||||
                       torch::Tensor& input)  // [..., 2 * d]
 | 
			
		||||
{
 | 
			
		||||
  int num_tokens = input.numel() / input.size(-1);
 | 
			
		||||
  int d = input.size(-1) / 2;
 | 
			
		||||
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(
 | 
			
		||||
      input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
 | 
			
		||||
        CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
 | 
			
		||||
        activation_kernel<scalar_t, gelu_tanh_act, true>(
 | 
			
		||||
            num_tokens, d, input.data_ptr<scalar_t>(),
 | 
			
		||||
            out.data_ptr<scalar_t>());
 | 
			
		||||
        CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
 | 
			
		||||
      });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void gelu_new(torch::Tensor& out, torch::Tensor& input) {
 | 
			
		||||
  int num_tokens = input.numel() / input.size(-1);
 | 
			
		||||
  int d = input.size(-1);
 | 
			
		||||
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
 | 
			
		||||
    CPU_KERNEL_GUARD_IN(gelu_new_impl)
 | 
			
		||||
    activation_kernel<scalar_t, gelu_new_act, false>(
 | 
			
		||||
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
 | 
			
		||||
    CPU_KERNEL_GUARD_OUT(gelu_new_impl)
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
 | 
			
		||||
  int num_tokens = input.numel() / input.size(-1);
 | 
			
		||||
  int d = input.size(-1);
 | 
			
		||||
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
 | 
			
		||||
    CPU_KERNEL_GUARD_IN(gelu_fast_impl)
 | 
			
		||||
    activation_kernel<scalar_t, gelu_fast_act, false>(
 | 
			
		||||
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
 | 
			
		||||
    CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										758
									
								
								csrc/cpu/attention.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										758
									
								
								csrc/cpu/attention.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,758 @@
 | 
			
		||||
#include "cpu_types.hpp"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
struct KernelVecType {
 | 
			
		||||
  using q_load_vec_type = void;
 | 
			
		||||
  using q_vec_type = void;
 | 
			
		||||
  using k_load_vec_type = void;
 | 
			
		||||
  using k_vec_type = void;
 | 
			
		||||
  using qk_acc_vec_type = void;
 | 
			
		||||
  using v_load_vec_type = void;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <>
 | 
			
		||||
struct KernelVecType<float> {
 | 
			
		||||
  using q_load_vec_type = vec_op::FP32Vec4;
 | 
			
		||||
  using q_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using k_load_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using k_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using qk_acc_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using v_load_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#ifdef __AVX512BF16__
 | 
			
		||||
template <>
 | 
			
		||||
struct KernelVecType<c10::BFloat16> {
 | 
			
		||||
  using q_load_vec_type = vec_op::BF16Vec8;
 | 
			
		||||
  using q_vec_type = vec_op::BF16Vec32;
 | 
			
		||||
  using k_load_vec_type = vec_op::BF16Vec32;
 | 
			
		||||
  using k_vec_type = vec_op::BF16Vec32;
 | 
			
		||||
  using qk_acc_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using v_load_vec_type = vec_op::BF16Vec16;
 | 
			
		||||
};
 | 
			
		||||
#else
 | 
			
		||||
template <>
 | 
			
		||||
struct KernelVecType<c10::BFloat16> {
 | 
			
		||||
  using q_load_vec_type = vec_op::BF16Vec8;
 | 
			
		||||
  using q_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using k_load_vec_type = vec_op::BF16Vec16;
 | 
			
		||||
  using k_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using qk_acc_vec_type = vec_op::FP32Vec16;
 | 
			
		||||
  using v_load_vec_type = vec_op::BF16Vec16;
 | 
			
		||||
};
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
 | 
			
		||||
                                           const int capacity) {
 | 
			
		||||
  T max = data[0];
 | 
			
		||||
  for (int i = 1; i < size; ++i) {
 | 
			
		||||
    max = max >= data[i] ? max : data[i];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  T sum = 0;
 | 
			
		||||
  for (int i = 0; i < size; ++i) {
 | 
			
		||||
    data[i] = std::exp(data[i] - max);
 | 
			
		||||
    sum += data[i];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int i = 0;
 | 
			
		||||
  for (; i < size; ++i) {
 | 
			
		||||
    data[i] /= sum;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  for (; i < capacity; ++i) {
 | 
			
		||||
    data[i] = 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return {max, sum};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
 | 
			
		||||
                                                const int capacity,
 | 
			
		||||
                                                const float alibi_slope,
 | 
			
		||||
                                                const int start_index,
 | 
			
		||||
                                                const int seq_len) {
 | 
			
		||||
  data[0] += alibi_slope * (start_index - seq_len + 1);
 | 
			
		||||
  T max = data[0];
 | 
			
		||||
  for (int i = 1; i < size; ++i) {
 | 
			
		||||
    T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
 | 
			
		||||
    data[i] = qk;
 | 
			
		||||
    max = max >= qk ? max : qk;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  T sum = 0;
 | 
			
		||||
  for (int i = 0; i < size; ++i) {
 | 
			
		||||
    data[i] = std::exp(data[i] - max);
 | 
			
		||||
    sum += data[i];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int i = 0;
 | 
			
		||||
  for (; i < size; ++i) {
 | 
			
		||||
    data[i] /= sum;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  for (; i < capacity; ++i) {
 | 
			
		||||
    data[i] = 0;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return {max, sum};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename T>
 | 
			
		||||
FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
 | 
			
		||||
                                        const int size) {
 | 
			
		||||
  T max = max_data[0];
 | 
			
		||||
  for (int i = 1; i < size; ++i) {
 | 
			
		||||
    max = max >= max_data[i] ? max : max_data[i];
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  T rescaled_sum = 0;
 | 
			
		||||
  for (int i = 0; i < size; ++i) {
 | 
			
		||||
    T rescale_factor = std::exp(max_data[i] - max);
 | 
			
		||||
    rescaled_sum += rescale_factor * sum_data[i];
 | 
			
		||||
    sum_data[i] *= rescale_factor;
 | 
			
		||||
  }
 | 
			
		||||
  for (int i = 0; i < size; ++i) {
 | 
			
		||||
    sum_data[i] /= rescaled_sum + 1e-8;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int x>
 | 
			
		||||
struct reduceQKBlockKernel {
 | 
			
		||||
  using q_load_vec_type = typename KernelVecType<scalar_t>::q_load_vec_type;
 | 
			
		||||
  using q_vec_type = typename KernelVecType<scalar_t>::q_vec_type;
 | 
			
		||||
  using k_load_vec_type = typename KernelVecType<scalar_t>::k_load_vec_type;
 | 
			
		||||
  using k_vec_type = typename KernelVecType<scalar_t>::k_vec_type;
 | 
			
		||||
  using qk_acc_vec_type = typename KernelVecType<scalar_t>::qk_acc_vec_type;
 | 
			
		||||
 | 
			
		||||
  constexpr static int TOKEN_PER_GROUP = k_load_vec_type::get_elem_num() / x;
 | 
			
		||||
  constexpr static int MAX_GROUP_NUM = 16 / TOKEN_PER_GROUP;
 | 
			
		||||
  constexpr static int UNROLL_GROUP_NUM = MAX_GROUP_NUM / 4;
 | 
			
		||||
 | 
			
		||||
  static_assert(MAX_GROUP_NUM == 8 || MAX_GROUP_NUM == 4);
 | 
			
		||||
  static_assert(k_load_vec_type::get_elem_num() % x == 0);
 | 
			
		||||
  static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
 | 
			
		||||
 | 
			
		||||
  FORCE_INLINE static void call(const scalar_t* __restrict__ q,
 | 
			
		||||
                                const scalar_t* __restrict__ k_block,
 | 
			
		||||
                                float* __restrict__ logits, float scale,
 | 
			
		||||
                                const int token_num) {
 | 
			
		||||
    const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
 | 
			
		||||
 | 
			
		||||
    qk_acc_vec_type group_accums[MAX_GROUP_NUM];
 | 
			
		||||
    if (token_num == BLOCK_SIZE) {
 | 
			
		||||
      for (int q_offset = 0; q_offset < HEAD_SIZE;
 | 
			
		||||
           q_offset += x, k_block += x * BLOCK_SIZE) {
 | 
			
		||||
        q_load_vec_type q_load_group_vec(q + q_offset);
 | 
			
		||||
        q_vec_type q_group_vec(q_load_group_vec);
 | 
			
		||||
 | 
			
		||||
        vec_op::unroll_loop<int, MAX_GROUP_NUM>(
 | 
			
		||||
            [k_block, &q_group_vec, &group_accums](int token_group_idx) {
 | 
			
		||||
              k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
 | 
			
		||||
                                                             TOKEN_PER_GROUP);
 | 
			
		||||
              k_vec_type k_group_vec(k_load_group_vec);
 | 
			
		||||
              vec_op::fma(group_accums[token_group_idx], q_group_vec,
 | 
			
		||||
                          k_group_vec);
 | 
			
		||||
              vec_op::prefetch(k_block + x * BLOCK_SIZE +
 | 
			
		||||
                               token_group_idx * x * TOKEN_PER_GROUP);
 | 
			
		||||
            });
 | 
			
		||||
      }
 | 
			
		||||
    } else {
 | 
			
		||||
      for (int q_offset = 0; q_offset < HEAD_SIZE;
 | 
			
		||||
           q_offset += x, k_block += x * BLOCK_SIZE) {
 | 
			
		||||
        q_load_vec_type q_load_group_vec(q + q_offset);
 | 
			
		||||
        q_vec_type q_group_vec(q_load_group_vec);
 | 
			
		||||
        for (int token_group_start = 0; token_group_start < group_num;
 | 
			
		||||
             token_group_start += UNROLL_GROUP_NUM) {
 | 
			
		||||
          vec_op::unroll_loop<int, UNROLL_GROUP_NUM>(
 | 
			
		||||
              [token_group_start, k_block, &q_group_vec,
 | 
			
		||||
               &group_accums](int token_group_idx) {
 | 
			
		||||
                token_group_idx += token_group_start;
 | 
			
		||||
                k_load_vec_type k_load_group_vec(k_block + token_group_idx * x *
 | 
			
		||||
                                                               TOKEN_PER_GROUP);
 | 
			
		||||
                k_vec_type k_group_vec(k_load_group_vec);
 | 
			
		||||
                vec_op::fma(group_accums[token_group_idx], q_group_vec,
 | 
			
		||||
                            k_group_vec);
 | 
			
		||||
                vec_op::prefetch(k_block + x * BLOCK_SIZE +
 | 
			
		||||
                                 token_group_idx * x * TOKEN_PER_GROUP);
 | 
			
		||||
              });
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (int token_group_idx = 0; token_group_idx < group_num;
 | 
			
		||||
         ++token_group_idx) {
 | 
			
		||||
      vec_op::unroll_loop<int, TOKEN_PER_GROUP>(
 | 
			
		||||
          [&group_accums, logits, scale, token_group_idx](int token_idx) {
 | 
			
		||||
            float dot_v =
 | 
			
		||||
                group_accums[token_group_idx]
 | 
			
		||||
                    .template reduce_sub_sum<qk_acc_vec_type::get_elem_num() /
 | 
			
		||||
                                             TOKEN_PER_GROUP>(token_idx);
 | 
			
		||||
            logits[token_group_idx * TOKEN_PER_GROUP + token_idx] =
 | 
			
		||||
                dot_v * scale;
 | 
			
		||||
          });
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
 | 
			
		||||
          int HEAD_PARTITION_SIZE, typename acc_t>
 | 
			
		||||
FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
 | 
			
		||||
                                   acc_t&& acc) {
 | 
			
		||||
  using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
 | 
			
		||||
  constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
 | 
			
		||||
  static_assert(BLOCK_SIZE == ELEM_NUM);
 | 
			
		||||
  vec_op::FP32Vec16 prob_vec(prob);
 | 
			
		||||
 | 
			
		||||
  vec_op::unroll_loop<int, HEAD_PARTITION_SIZE>([&](int head_elem_idx) {
 | 
			
		||||
    v_load_vec_type v_vec(v_block + BLOCK_SIZE * head_elem_idx);
 | 
			
		||||
    vec_op::FP32Vec16 fp32_v_vec(v_vec);
 | 
			
		||||
    acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
};  // namespace
 | 
			
		||||
 | 
			
		||||
// Paged attention v1
 | 
			
		||||
namespace {
 | 
			
		||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
 | 
			
		||||
struct paged_attention_v1_impl {
 | 
			
		||||
  static void call(
 | 
			
		||||
      scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
 | 
			
		||||
      const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
 | 
			
		||||
      const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
 | 
			
		||||
                                             // head_size/x, block_size, x]
 | 
			
		||||
      const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
 | 
			
		||||
                                             // head_size, block_size]
 | 
			
		||||
      const int num_kv_heads, const float scale,
 | 
			
		||||
      const int* __restrict__ block_tables,  // [num_seqs,
 | 
			
		||||
                                             // max_num_blocks_per_seq]
 | 
			
		||||
      const int* __restrict__ seq_lens,      // [num_seqs]
 | 
			
		||||
      const int max_num_blocks_per_seq,
 | 
			
		||||
      const float* __restrict__ alibi_slopes,  // [num_heads]
 | 
			
		||||
      const int q_stride, const int kv_block_stride, const int kv_head_stride,
 | 
			
		||||
      const int num_seqs, const int num_heads) {
 | 
			
		||||
    constexpr int x = 16 / sizeof(scalar_t);
 | 
			
		||||
    const int num_queries_per_kv = num_heads / num_kv_heads;
 | 
			
		||||
 | 
			
		||||
    static_assert(BLOCK_SIZE == 16);
 | 
			
		||||
 | 
			
		||||
    int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
 | 
			
		||||
    int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
 | 
			
		||||
    TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
 | 
			
		||||
 | 
			
		||||
    const int parallel_work_item_num = omp_get_max_threads();
 | 
			
		||||
 | 
			
		||||
    size_t logits_bytes =
 | 
			
		||||
        parallel_work_item_num * max_seq_len_padded * sizeof(float);
 | 
			
		||||
    float* logits = (float*)std::aligned_alloc(
 | 
			
		||||
        64, logits_bytes);  // Cacheline alignment for each context token.
 | 
			
		||||
                            // [parallel_work_item_num, max_seq_len_padded]
 | 
			
		||||
 | 
			
		||||
#pragma omp parallel for collapse(2) schedule(dynamic, 1)
 | 
			
		||||
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
 | 
			
		||||
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
 | 
			
		||||
        int seq_len = seq_lens[seq_idx];
 | 
			
		||||
        const int* seq_block_table =
 | 
			
		||||
            block_tables + max_num_blocks_per_seq * seq_idx;
 | 
			
		||||
        const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
 | 
			
		||||
        const int64_t kv_head_idx = head_idx / num_queries_per_kv;
 | 
			
		||||
        const scalar_t* __restrict__ q_vec_ptr =
 | 
			
		||||
            q + seq_idx * q_stride + head_idx * HEAD_SIZE;
 | 
			
		||||
        const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
 | 
			
		||||
        float* __restrict__ thread_block_logits =
 | 
			
		||||
            logits + omp_get_thread_num() * max_seq_len_padded;
 | 
			
		||||
 | 
			
		||||
        // Compute logits
 | 
			
		||||
        for (int block_idx = 0; block_idx < block_num; ++block_idx) {
 | 
			
		||||
          const int64_t physical_block_idx = seq_block_table[block_idx];
 | 
			
		||||
          const scalar_t* __restrict__ k_block_cache_ptr =
 | 
			
		||||
              k_cache + physical_block_idx * kv_block_stride +
 | 
			
		||||
              kv_head_idx * kv_head_stride;
 | 
			
		||||
          float* __restrict__ head_block_logits =
 | 
			
		||||
              thread_block_logits + block_idx * BLOCK_SIZE;
 | 
			
		||||
 | 
			
		||||
          reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
 | 
			
		||||
              q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
 | 
			
		||||
              block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Compute softmax
 | 
			
		||||
        if (alibi_slopes) {
 | 
			
		||||
          reduceSoftmaxAlibi(thread_block_logits, seq_len,
 | 
			
		||||
                             block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
 | 
			
		||||
                             seq_len);
 | 
			
		||||
        } else {
 | 
			
		||||
          reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Compute value
 | 
			
		||||
        constexpr int head_elem_num_per_partition = 16;
 | 
			
		||||
        constexpr int head_partition_num =
 | 
			
		||||
            HEAD_SIZE / head_elem_num_per_partition;
 | 
			
		||||
        for (int head_part_idx = 0; head_part_idx < head_partition_num;
 | 
			
		||||
             ++head_part_idx) {
 | 
			
		||||
          vec_op::FP32Vec16 accums[head_elem_num_per_partition];
 | 
			
		||||
          scalar_t* __restrict__ out_ptr =
 | 
			
		||||
              out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
 | 
			
		||||
              head_part_idx * head_elem_num_per_partition;
 | 
			
		||||
          for (int block_idx = 0; block_idx < block_num; ++block_idx) {
 | 
			
		||||
            const int64_t physical_block_idx = seq_block_table[block_idx];
 | 
			
		||||
            const float* __restrict__ prob_vec_ptr =
 | 
			
		||||
                thread_block_logits + block_idx * BLOCK_SIZE;
 | 
			
		||||
            const scalar_t* __restrict__ v_block_cache_ptr =
 | 
			
		||||
                v_cache + physical_block_idx * kv_block_stride +
 | 
			
		||||
                kv_head_idx * kv_head_stride +
 | 
			
		||||
                BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
 | 
			
		||||
            reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
 | 
			
		||||
                             head_elem_num_per_partition>(
 | 
			
		||||
                prob_vec_ptr, v_block_cache_ptr, accums);
 | 
			
		||||
 | 
			
		||||
            if (block_idx != block_num - 1) {
 | 
			
		||||
              const int64_t next_physical_block_idx =
 | 
			
		||||
                  seq_block_table[block_idx + 1];
 | 
			
		||||
              const scalar_t* __restrict__ next_v_block_cache_ptr =
 | 
			
		||||
                  v_cache + next_physical_block_idx * kv_block_stride +
 | 
			
		||||
                  kv_head_idx * kv_head_stride +
 | 
			
		||||
                  BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
 | 
			
		||||
              vec_op::unroll_loop<int, head_elem_num_per_partition>(
 | 
			
		||||
                  [&](int head_elem_idx) {
 | 
			
		||||
                    if (head_elem_idx % 2 == 0) {
 | 
			
		||||
                      vec_op::prefetch(next_v_block_cache_ptr +
 | 
			
		||||
                                       BLOCK_SIZE * head_elem_idx);
 | 
			
		||||
                    }
 | 
			
		||||
                  });
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          vec_op::unroll_loop<int, head_elem_num_per_partition>(
 | 
			
		||||
              [&](int head_elem_idx) {
 | 
			
		||||
                float value = accums[head_elem_idx].reduce_sum();
 | 
			
		||||
                vec_op::storeFP32(value, out_ptr + head_elem_idx);
 | 
			
		||||
              });
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    std::free(logits);
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                   \
 | 
			
		||||
  paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call(                     \
 | 
			
		||||
      out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
 | 
			
		||||
      block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,                  \
 | 
			
		||||
      alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs,   \
 | 
			
		||||
      num_heads);
 | 
			
		||||
 | 
			
		||||
template <typename T, int BLOCK_SIZE>
 | 
			
		||||
void paged_attention_v1_impl_launcher(
 | 
			
		||||
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
 | 
			
		||||
    torch::Tensor& value_cache, int num_kv_heads, float scale,
 | 
			
		||||
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
 | 
			
		||||
    const c10::optional<torch::Tensor>& alibi_slopes) {
 | 
			
		||||
  int num_seqs = query.size(0);
 | 
			
		||||
  int num_heads = query.size(1);
 | 
			
		||||
  int head_size = query.size(2);
 | 
			
		||||
  int max_num_blocks_per_seq = block_tables.size(1);
 | 
			
		||||
  int q_stride = query.stride(0);
 | 
			
		||||
  int kv_block_stride = key_cache.stride(0);
 | 
			
		||||
  int kv_head_stride = key_cache.stride(1);
 | 
			
		||||
 | 
			
		||||
  // NOTE: alibi_slopes is optional.
 | 
			
		||||
  const float* alibi_slopes_ptr =
 | 
			
		||||
      alibi_slopes
 | 
			
		||||
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
 | 
			
		||||
          : nullptr;
 | 
			
		||||
 | 
			
		||||
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
 | 
			
		||||
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
 | 
			
		||||
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
 | 
			
		||||
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
 | 
			
		||||
  int* block_tables_ptr = block_tables.data_ptr<int>();
 | 
			
		||||
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
 | 
			
		||||
 | 
			
		||||
  switch (head_size) {
 | 
			
		||||
    case 64:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 80:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 96:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 112:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 128:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 192:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 256:
 | 
			
		||||
      LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    default:
 | 
			
		||||
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
 | 
			
		||||
      break;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE)                               \
 | 
			
		||||
  paged_attention_v1_impl_launcher<T, BLOCK_SIZE>(                           \
 | 
			
		||||
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
 | 
			
		||||
      seq_lens, max_seq_len, alibi_slopes);
 | 
			
		||||
 | 
			
		||||
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T)                     \
 | 
			
		||||
  switch (block_size) {                                           \
 | 
			
		||||
    case 16:                                                      \
 | 
			
		||||
      CALL_V1_KERNEL_LAUNCHER(T, 16);                             \
 | 
			
		||||
      break;                                                      \
 | 
			
		||||
    default:                                                      \
 | 
			
		||||
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
 | 
			
		||||
      break;                                                      \
 | 
			
		||||
  }
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
void paged_attention_v1(
 | 
			
		||||
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
 | 
			
		||||
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
 | 
			
		||||
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
 | 
			
		||||
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
 | 
			
		||||
    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
 | 
			
		||||
    const int64_t blocksparse_local_blocks,
 | 
			
		||||
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
 | 
			
		||||
    const int64_t blocksparse_head_sliding_step) {
 | 
			
		||||
  TORCH_CHECK(kv_scale == 1.0f);
 | 
			
		||||
  TORCH_CHECK(blocksparse_vert_stride <= 1,
 | 
			
		||||
              "CPU backend does not support blocksparse attention yet.");
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
 | 
			
		||||
                               [&] {
 | 
			
		||||
                                 CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
 | 
			
		||||
                                 CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
 | 
			
		||||
                                 CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl)
 | 
			
		||||
                               });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Paged attention v2
 | 
			
		||||
namespace {
 | 
			
		||||
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
 | 
			
		||||
struct paged_attention_v2_impl {
 | 
			
		||||
  static void call(
 | 
			
		||||
      scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
 | 
			
		||||
      float* __restrict__ exp_sums,          // [num_seqs, num_heads,
 | 
			
		||||
                                             // max_num_partitions]
 | 
			
		||||
      float* __restrict__ max_logits,        // [num_seqs, num_heads,
 | 
			
		||||
                                             // max_num_partitions]
 | 
			
		||||
      scalar_t* __restrict__ tmp_out,        // [num_seqs, num_heads,
 | 
			
		||||
                                             // max_num_partitions, head_size]
 | 
			
		||||
      const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
 | 
			
		||||
      const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
 | 
			
		||||
                                             // head_size/x, block_size, x]
 | 
			
		||||
      const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
 | 
			
		||||
                                             // head_size, block_size]
 | 
			
		||||
      const int num_kv_heads, const float scale,
 | 
			
		||||
      const int* __restrict__ block_tables,  // [num_seqs,
 | 
			
		||||
                                             // max_num_blocks_per_seq]
 | 
			
		||||
      const int* __restrict__ seq_lens,      // [num_seqs]
 | 
			
		||||
      const int max_num_blocks_per_seq,
 | 
			
		||||
      const float* __restrict__ alibi_slopes,  // [num_heads]
 | 
			
		||||
      const int q_stride, const int kv_block_stride, const int kv_head_stride,
 | 
			
		||||
      const int num_seqs, const int num_heads, const int max_num_partitions) {
 | 
			
		||||
    constexpr int x = 16 / sizeof(scalar_t);
 | 
			
		||||
    const int num_queries_per_kv = num_heads / num_kv_heads;
 | 
			
		||||
 | 
			
		||||
    static_assert(BLOCK_SIZE == 16);
 | 
			
		||||
    static_assert(PARTITION_SIZE * sizeof(float) % 64 == 0);
 | 
			
		||||
    static_assert(PARTITION_SIZE % BLOCK_SIZE == 0);
 | 
			
		||||
 | 
			
		||||
#pragma omp parallel for collapse(3) schedule(static, 1)
 | 
			
		||||
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
 | 
			
		||||
      for (int partition_idx = 0; partition_idx < max_num_partitions;
 | 
			
		||||
           ++partition_idx) {
 | 
			
		||||
        for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
 | 
			
		||||
          const int seq_len = seq_lens[seq_idx];
 | 
			
		||||
          const int start_token_idx = partition_idx * PARTITION_SIZE;
 | 
			
		||||
 | 
			
		||||
          if (start_token_idx >= seq_len) continue;
 | 
			
		||||
 | 
			
		||||
          const int partition_num =
 | 
			
		||||
              (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
 | 
			
		||||
          const bool no_reduce = (partition_num == 1);
 | 
			
		||||
          const int token_num =
 | 
			
		||||
              (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
 | 
			
		||||
               start_token_idx);
 | 
			
		||||
          const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
 | 
			
		||||
          const int last_block_token_num =
 | 
			
		||||
              token_num - (block_num - 1) * BLOCK_SIZE;
 | 
			
		||||
          const int* seq_block_table = block_tables +
 | 
			
		||||
                                       max_num_blocks_per_seq * seq_idx +
 | 
			
		||||
                                       start_token_idx / BLOCK_SIZE;
 | 
			
		||||
          const int64_t kv_head_idx = head_idx / num_queries_per_kv;
 | 
			
		||||
          const scalar_t* __restrict__ q_vec_ptr =
 | 
			
		||||
              q + seq_idx * q_stride + head_idx * HEAD_SIZE;
 | 
			
		||||
 | 
			
		||||
          float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
 | 
			
		||||
 | 
			
		||||
          // Compute logits
 | 
			
		||||
          for (int block_idx = 0; block_idx < block_num; ++block_idx) {
 | 
			
		||||
            const int64_t physical_block_idx = seq_block_table[block_idx];
 | 
			
		||||
            const scalar_t* __restrict__ k_block_cache_ptr =
 | 
			
		||||
                k_cache + physical_block_idx * kv_block_stride +
 | 
			
		||||
                kv_head_idx * kv_head_stride;
 | 
			
		||||
            float* __restrict__ head_block_logits =
 | 
			
		||||
                logits + block_idx * BLOCK_SIZE;
 | 
			
		||||
 | 
			
		||||
            reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
 | 
			
		||||
                q_vec_ptr, k_block_cache_ptr, head_block_logits, scale,
 | 
			
		||||
                block_idx == block_num - 1 ? last_block_token_num : BLOCK_SIZE);
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          std::pair<float, float> max_and_sum;
 | 
			
		||||
          if (alibi_slopes) {
 | 
			
		||||
            max_and_sum = reduceSoftmaxAlibi(
 | 
			
		||||
                logits, token_num, block_num * BLOCK_SIZE,
 | 
			
		||||
                alibi_slopes[head_idx], start_token_idx, seq_len);
 | 
			
		||||
          } else {
 | 
			
		||||
            max_and_sum =
 | 
			
		||||
                reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          auto&& [max_logit, exp_sum] = max_and_sum;
 | 
			
		||||
 | 
			
		||||
          scalar_t* __restrict__ output_buffer = nullptr;
 | 
			
		||||
          if (!no_reduce) {
 | 
			
		||||
            auto idx = seq_idx * num_heads * max_num_partitions +
 | 
			
		||||
                       head_idx * max_num_partitions + partition_idx;
 | 
			
		||||
            max_logits[idx] = max_logit;
 | 
			
		||||
            exp_sums[idx] = exp_sum;
 | 
			
		||||
            output_buffer =
 | 
			
		||||
                tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
 | 
			
		||||
                head_idx * max_num_partitions * HEAD_SIZE +
 | 
			
		||||
                partition_idx * HEAD_SIZE;
 | 
			
		||||
          } else {
 | 
			
		||||
            output_buffer =
 | 
			
		||||
                out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          // Compute value
 | 
			
		||||
          constexpr int head_elem_num_per_partition = 16;
 | 
			
		||||
          constexpr int head_partition_num =
 | 
			
		||||
              HEAD_SIZE / head_elem_num_per_partition;
 | 
			
		||||
          for (int head_part_idx = 0; head_part_idx < head_partition_num;
 | 
			
		||||
               ++head_part_idx) {
 | 
			
		||||
            vec_op::FP32Vec16 accums[head_elem_num_per_partition];
 | 
			
		||||
            scalar_t* __restrict__ out_ptr =
 | 
			
		||||
                output_buffer + head_part_idx * head_elem_num_per_partition;
 | 
			
		||||
            for (int block_idx = 0; block_idx < block_num; ++block_idx) {
 | 
			
		||||
              const int64_t physical_block_idx = seq_block_table[block_idx];
 | 
			
		||||
              const float* __restrict__ prob_vec_ptr =
 | 
			
		||||
                  logits + block_idx * BLOCK_SIZE;
 | 
			
		||||
              const scalar_t* __restrict__ v_block_cache_ptr =
 | 
			
		||||
                  v_cache + physical_block_idx * kv_block_stride +
 | 
			
		||||
                  kv_head_idx * kv_head_stride +
 | 
			
		||||
                  BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
 | 
			
		||||
              reduceValueBlock<scalar_t, HEAD_SIZE, BLOCK_SIZE,
 | 
			
		||||
                               head_elem_num_per_partition>(
 | 
			
		||||
                  prob_vec_ptr, v_block_cache_ptr, accums);
 | 
			
		||||
 | 
			
		||||
              if (block_idx != block_num - 1) {
 | 
			
		||||
                const int64_t next_physical_block_idx =
 | 
			
		||||
                    seq_block_table[block_idx + 1];
 | 
			
		||||
                const scalar_t* __restrict__ next_v_block_cache_ptr =
 | 
			
		||||
                    v_cache + next_physical_block_idx * kv_block_stride +
 | 
			
		||||
                    kv_head_idx * kv_head_stride +
 | 
			
		||||
                    BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
 | 
			
		||||
                vec_op::unroll_loop<int, head_elem_num_per_partition>(
 | 
			
		||||
                    [&](int head_elem_idx) {
 | 
			
		||||
                      if (head_elem_idx % 2 == 0) {
 | 
			
		||||
                        vec_op::prefetch(next_v_block_cache_ptr +
 | 
			
		||||
                                         BLOCK_SIZE * head_elem_idx);
 | 
			
		||||
                      }
 | 
			
		||||
                    });
 | 
			
		||||
              }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            vec_op::unroll_loop<int, head_elem_num_per_partition>(
 | 
			
		||||
                [&](int head_elem_idx) {
 | 
			
		||||
                  float value = accums[head_elem_idx].reduce_sum();
 | 
			
		||||
                  vec_op::storeFP32(value, out_ptr + head_elem_idx);
 | 
			
		||||
                });
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Rescale partition softmax and store the factors to exp_sums
 | 
			
		||||
#pragma omp parallel for collapse(2) schedule(static, 1)
 | 
			
		||||
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
 | 
			
		||||
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
 | 
			
		||||
        const int seq_len = seq_lens[seq_idx];
 | 
			
		||||
        const int partition_num =
 | 
			
		||||
            (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
 | 
			
		||||
 | 
			
		||||
        if (partition_num == 1) continue;
 | 
			
		||||
 | 
			
		||||
        reducePartitonSoftmax(
 | 
			
		||||
            max_logits + seq_idx * num_heads * max_num_partitions +
 | 
			
		||||
                head_idx * max_num_partitions,
 | 
			
		||||
            exp_sums + seq_idx * num_heads * max_num_partitions +
 | 
			
		||||
                head_idx * max_num_partitions,
 | 
			
		||||
            partition_num);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Reduce values
 | 
			
		||||
    using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
 | 
			
		||||
    static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
 | 
			
		||||
    constexpr int head_elem_num_per_group =
 | 
			
		||||
        16;  // Note: didn't align with the cacheline size, due to some
 | 
			
		||||
             // HEAD_SIZE didn't align with 64 bytes
 | 
			
		||||
    static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
 | 
			
		||||
    constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
 | 
			
		||||
    const float* __restrict__ rescale_factors = exp_sums;
 | 
			
		||||
#pragma omp parallel for collapse(3) schedule(static, 1)
 | 
			
		||||
    for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
 | 
			
		||||
      for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
 | 
			
		||||
        for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
 | 
			
		||||
          const int seq_len = seq_lens[seq_idx];
 | 
			
		||||
          const int partition_num =
 | 
			
		||||
              (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
 | 
			
		||||
 | 
			
		||||
          if (partition_num == 1) continue;
 | 
			
		||||
 | 
			
		||||
          const float* __restrict__ seq_head_rescale_factors =
 | 
			
		||||
              rescale_factors + seq_idx * num_heads * max_num_partitions +
 | 
			
		||||
              head_idx * max_num_partitions;
 | 
			
		||||
          const scalar_t* __restrict__ seq_head_tmp_out =
 | 
			
		||||
              tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
 | 
			
		||||
              head_idx * max_num_partitions * HEAD_SIZE +
 | 
			
		||||
              group_idx * head_elem_num_per_group;
 | 
			
		||||
          scalar_t* __restrict__ seq_head_output =
 | 
			
		||||
              out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
 | 
			
		||||
              group_idx * head_elem_num_per_group;
 | 
			
		||||
 | 
			
		||||
          vec_op::FP32Vec16 acc;
 | 
			
		||||
          for (int i = 0; i < partition_num; ++i) {
 | 
			
		||||
            vec_op::FP32Vec16 rescale_factor(seq_head_rescale_factors[i]);
 | 
			
		||||
            v_load_vec_type value(seq_head_tmp_out + i * HEAD_SIZE);
 | 
			
		||||
            vec_op::FP32Vec16 fp32_value(value);
 | 
			
		||||
            acc = acc + fp32_value * rescale_factor;
 | 
			
		||||
          }
 | 
			
		||||
          v_load_vec_type cast_acc(acc);
 | 
			
		||||
          cast_acc.save(seq_head_output);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE)                 \
 | 
			
		||||
  paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call(   \
 | 
			
		||||
      out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr,         \
 | 
			
		||||
      key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
 | 
			
		||||
      seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,      \
 | 
			
		||||
      kv_block_stride, kv_head_stride, num_seqs, num_heads,                  \
 | 
			
		||||
      max_num_partitions);
 | 
			
		||||
 | 
			
		||||
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
 | 
			
		||||
void paged_attention_v2_impl_launcher(
 | 
			
		||||
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
 | 
			
		||||
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
 | 
			
		||||
    torch::Tensor& value_cache, int num_kv_heads, float scale,
 | 
			
		||||
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
 | 
			
		||||
    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
 | 
			
		||||
  int num_seqs = query.size(0);
 | 
			
		||||
  int num_heads = query.size(1);
 | 
			
		||||
  int head_size = query.size(2);
 | 
			
		||||
  int max_num_blocks_per_seq = block_tables.size(1);
 | 
			
		||||
  int q_stride = query.stride(0);
 | 
			
		||||
  int kv_block_stride = key_cache.stride(0);
 | 
			
		||||
  int kv_head_stride = key_cache.stride(1);
 | 
			
		||||
  int max_num_partitions = exp_sums.size(-1);
 | 
			
		||||
 | 
			
		||||
  // NOTE: alibi_slopes is optional.
 | 
			
		||||
  const float* alibi_slopes_ptr =
 | 
			
		||||
      alibi_slopes
 | 
			
		||||
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
 | 
			
		||||
          : nullptr;
 | 
			
		||||
 | 
			
		||||
  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
 | 
			
		||||
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
 | 
			
		||||
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
 | 
			
		||||
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
 | 
			
		||||
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
 | 
			
		||||
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
 | 
			
		||||
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
 | 
			
		||||
  int* block_tables_ptr = block_tables.data_ptr<int>();
 | 
			
		||||
  int* seq_lens_ptr = seq_lens.data_ptr<int>();
 | 
			
		||||
 | 
			
		||||
  switch (head_size) {
 | 
			
		||||
    case 64:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 80:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 96:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 112:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 128:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 192:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    case 256:
 | 
			
		||||
      LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
 | 
			
		||||
      break;
 | 
			
		||||
    default:
 | 
			
		||||
      TORCH_CHECK(false, "Unsupported head size: ", head_size);
 | 
			
		||||
      break;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE)                              \
 | 
			
		||||
  paged_attention_v2_impl_launcher<T, BLOCK_SIZE>(                          \
 | 
			
		||||
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,    \
 | 
			
		||||
      num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
 | 
			
		||||
      alibi_slopes);
 | 
			
		||||
 | 
			
		||||
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T)                     \
 | 
			
		||||
  switch (block_size) {                                           \
 | 
			
		||||
    case 16:                                                      \
 | 
			
		||||
      CALL_V2_KERNEL_LAUNCHER(T, 16);                             \
 | 
			
		||||
      break;                                                      \
 | 
			
		||||
    default:                                                      \
 | 
			
		||||
      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
 | 
			
		||||
      break;                                                      \
 | 
			
		||||
  }
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
void paged_attention_v2(
 | 
			
		||||
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
 | 
			
		||||
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
 | 
			
		||||
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
 | 
			
		||||
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
 | 
			
		||||
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
 | 
			
		||||
    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
 | 
			
		||||
    const int64_t blocksparse_local_blocks,
 | 
			
		||||
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
 | 
			
		||||
    const int64_t blocksparse_head_sliding_step) {
 | 
			
		||||
  TORCH_CHECK(kv_scale == 1.0f);
 | 
			
		||||
  TORCH_CHECK(blocksparse_vert_stride <= 1,
 | 
			
		||||
              "CPU backend does not support blocksparse attention yet.");
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
 | 
			
		||||
                               [&] {
 | 
			
		||||
                                 CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
 | 
			
		||||
                                 CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t);
 | 
			
		||||
                                 CPU_KERNEL_GUARD_OUT(paged_attention_v2_impl)
 | 
			
		||||
                               });
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										137
									
								
								csrc/cpu/cache.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								csrc/cpu/cache.cpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,137 @@
 | 
			
		||||
#include <map>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
#include "cpu_types.hpp"
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
 | 
			
		||||
                          std::vector<torch::Tensor> const& value_caches,
 | 
			
		||||
                          const torch::Tensor& mapping_pairs,
 | 
			
		||||
                          const int element_num_per_block,
 | 
			
		||||
                          const int layer_num) {
 | 
			
		||||
  const size_t pair_num = mapping_pairs.size(0);
 | 
			
		||||
  const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
 | 
			
		||||
#pragma omp parallel for collapse(2)
 | 
			
		||||
  for (int layer = 0; layer < layer_num; ++layer) {
 | 
			
		||||
    for (size_t pair = 0; pair < pair_num; ++pair) {
 | 
			
		||||
      int64_t source_offset =
 | 
			
		||||
          element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
 | 
			
		||||
      int64_t target_offset =
 | 
			
		||||
          element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
 | 
			
		||||
      scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
 | 
			
		||||
      scalar_t* source_ptr = key_cache_ptr + source_offset;
 | 
			
		||||
      scalar_t* target_ptr = key_cache_ptr + target_offset;
 | 
			
		||||
      std::memcpy(target_ptr, source_ptr, block_bytes);
 | 
			
		||||
 | 
			
		||||
      scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
 | 
			
		||||
      source_ptr = value_cache_ptr + source_offset;
 | 
			
		||||
      target_ptr = value_cache_ptr + target_offset;
 | 
			
		||||
      std::memcpy(target_ptr, source_ptr, block_bytes);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename scalar_t>
 | 
			
		||||
void reshape_and_cache_cpu_impl(
 | 
			
		||||
    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
 | 
			
		||||
    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
 | 
			
		||||
    const int64_t* __restrict__ slot_mapping, const int num_tokens,
 | 
			
		||||
    const int key_stride, const int value_stride, const int num_heads,
 | 
			
		||||
    const int head_size, const int block_size, const int x) {
 | 
			
		||||
  const int block_elem_num = num_heads * head_size * block_size;
 | 
			
		||||
 | 
			
		||||
#pragma omp parallel for collapse(2)
 | 
			
		||||
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
 | 
			
		||||
    for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
 | 
			
		||||
      const int64_t slot_idx = slot_mapping[token_idx];
 | 
			
		||||
      if (slot_idx >= 0) {
 | 
			
		||||
        int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
 | 
			
		||||
        int src_value_head_idx =
 | 
			
		||||
            token_idx * value_stride + head_idx * head_size;
 | 
			
		||||
        const scalar_t* src_key_head_ptr = key + src_key_head_idx;
 | 
			
		||||
        const scalar_t* src_value_head_ptr = value + src_value_head_idx;
 | 
			
		||||
        const int64_t block_index = slot_idx / block_size;
 | 
			
		||||
        const int64_t block_offset = slot_idx % block_size;
 | 
			
		||||
        scalar_t* target_key_head_ptr = key_cache +
 | 
			
		||||
                                        block_elem_num * block_index +
 | 
			
		||||
                                        head_idx * block_size * head_size;
 | 
			
		||||
        scalar_t* target_value_head_ptr = value_cache +
 | 
			
		||||
                                          block_elem_num * block_index +
 | 
			
		||||
                                          head_idx * block_size * head_size;
 | 
			
		||||
 | 
			
		||||
        for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
 | 
			
		||||
          const int64_t target_offset =
 | 
			
		||||
              src_key_idx * block_size + block_offset * x;
 | 
			
		||||
          for (int i = 0; i < x; ++i) {
 | 
			
		||||
            target_key_head_ptr[target_offset + i] =
 | 
			
		||||
                src_key_head_ptr[src_key_idx + i];
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for (int src_value_idx = 0; src_value_idx < head_size;
 | 
			
		||||
             ++src_value_idx) {
 | 
			
		||||
          const int64_t target_offset =
 | 
			
		||||
              src_value_idx * block_size + block_offset;
 | 
			
		||||
          target_value_head_ptr[target_offset] =
 | 
			
		||||
              src_value_head_ptr[src_value_idx];
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
};  // namespace
 | 
			
		||||
 | 
			
		||||
// Note: the key_caches and value_caches vectors are constant but
 | 
			
		||||
// not the Tensors they contain. The vectors need to be const refs
 | 
			
		||||
// in order to satisfy pytorch's C++ operator registration code.
 | 
			
		||||
void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 | 
			
		||||
                 std::vector<torch::Tensor> const& value_caches,
 | 
			
		||||
                 const torch::Tensor& block_mapping) {
 | 
			
		||||
  unsigned num_layers = key_caches.size();
 | 
			
		||||
  TORCH_CHECK(num_layers == value_caches.size());
 | 
			
		||||
  if (num_layers == 0) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  const int element_num_per_block = key_caches[0][0].numel();
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(
 | 
			
		||||
      key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
 | 
			
		||||
        CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
 | 
			
		||||
        copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
 | 
			
		||||
                                       element_num_per_block, num_layers);
 | 
			
		||||
        CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
 | 
			
		||||
      });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
 | 
			
		||||
                       torch::Tensor& key_cache, torch::Tensor& value_cache,
 | 
			
		||||
                       torch::Tensor& slot_mapping,
 | 
			
		||||
                       const std::string& kv_cache_dtype, double kv_scale) {
 | 
			
		||||
  TORCH_CHECK(kv_scale == 1.0f);
 | 
			
		||||
 | 
			
		||||
  int num_tokens = key.size(0);
 | 
			
		||||
  int num_heads = key.size(1);
 | 
			
		||||
  int head_size = key.size(2);
 | 
			
		||||
  int block_size = key_cache.size(3);
 | 
			
		||||
  int x = key_cache.size(4);
 | 
			
		||||
 | 
			
		||||
  int key_stride = key.stride(0);
 | 
			
		||||
  int value_stride = value.stride(0);
 | 
			
		||||
 | 
			
		||||
  VLLM_DISPATCH_FLOATING_TYPES(
 | 
			
		||||
      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
 | 
			
		||||
        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
 | 
			
		||||
        reshape_and_cache_cpu_impl<scalar_t>(
 | 
			
		||||
            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
 | 
			
		||||
            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
 | 
			
		||||
            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
 | 
			
		||||
            value_stride, num_heads, head_size, block_size, x);
 | 
			
		||||
        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
 | 
			
		||||
      });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
 | 
			
		||||
                 const torch::Tensor& block_mapping) {
 | 
			
		||||
  TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										352
									
								
								csrc/cpu/cpu_types.hpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										352
									
								
								csrc/cpu/cpu_types.hpp
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,352 @@

#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP

#include <immintrin.h>
#include <torch/all.h>

namespace vec_op {

// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)                                 \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                         \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME)                                              \
  std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#endif

#define FORCE_INLINE __attribute__((always_inline)) inline

namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
  (f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace

template <typename T, T count, typename F,
          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}

template <typename T> struct Vec {
  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};

struct FP32Vec8;
struct FP32Vec16;

#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128h reg;

  explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}

  explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}

  explicit FP16Vec8(__m128h data) : reg(data) {}

  FP16Vec8 operator*(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_mul_ph(reg, b.reg));
  }

  FP16Vec8 operator+(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_add_ph(reg, b.reg));
  }

  FP16Vec8 operator-(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_sub_ph(reg, b.reg));
  }

  FP16Vec8 operator/(const FP16Vec8 &b) const {
    return FP16Vec8(_mm_div_ph(reg, b.reg));
  }

  void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif

struct BF16Vec8 : public Vec<BF16Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;

  __m128i reg;

  explicit BF16Vec8(const void *ptr)
      : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}

  explicit BF16Vec8(const FP32Vec8 &);

  void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};

struct BF16Vec16 : public Vec<BF16Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  __m256i reg;

  explicit BF16Vec16(const void *ptr)
      : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}

  explicit BF16Vec16(const FP32Vec16 &);

  void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};

struct BF16Vec32 : public Vec<BF16Vec32> {
  constexpr static int VEC_ELEM_NUM = 32;

  __m512i reg;

  explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}

  explicit BF16Vec32(__m512i data) : reg(data) {}

  explicit BF16Vec32(BF16Vec8 &vec8_data)
      : reg((__m512i)_mm512_inserti32x4(
            _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
                                                      (__m128i)vec8_data.reg),
                                                  (__m128i)vec8_data.reg, 1),
                               (__m128i)vec8_data.reg, 2),
            (__m128i)vec8_data.reg, 3)) {}

  void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};

struct FP32Vec4 : public Vec<FP32Vec4> {
  constexpr static int VEC_ELEM_NUM = 4;
  union AliasReg {
    __m128 reg;
    float values[VEC_ELEM_NUM];
  };

  __m128 reg;

  explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}

  explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}

  explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}

  explicit FP32Vec4(__m128 data) : reg(data) {}

  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};

struct FP32Vec8 : public Vec<FP32Vec8> {
  constexpr static int VEC_ELEM_NUM = 8;
  union AliasReg {
    __m256 reg;
    float values[VEC_ELEM_NUM];
  };

  __m256 reg;

  explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}

  explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}

  explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}

  explicit FP32Vec8(__m256 data) : reg(data) {}

  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}

#ifdef __AVX512FP16__
  explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif

  explicit FP32Vec8(const BF16Vec8 &v)
      : reg(_mm256_castsi256_ps(
            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}

  float reduce_sum() const {
    AliasReg ar;
    ar.reg = reg;
    float result = 0;
    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });

    return result;
  }

  FP32Vec8 exp() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
                                  expf(ar.values[5]), expf(ar.values[4]),
                                  expf(ar.values[3]), expf(ar.values[2]),
                                  expf(ar.values[1]), expf(ar.values[0])));
  }

  FP32Vec8 tanh() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
                                  tanhf(ar.values[5]), tanhf(ar.values[4]),
                                  tanhf(ar.values[3]), tanhf(ar.values[2]),
                                  tanhf(ar.values[1]), tanhf(ar.values[0])));
  }

  FP32Vec8 er() const {
    AliasReg ar;
    ar.reg = reg;
    return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
                                  erf(ar.values[5]), erf(ar.values[4]),
                                  erf(ar.values[3]), erf(ar.values[2]),
                                  erf(ar.values[1]), erf(ar.values[0])));
  }

  FP32Vec8 operator*(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_mul_ps(reg, b.reg));
  }

  FP32Vec8 operator+(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_add_ps(reg, b.reg));
  }

  FP32Vec8 operator-(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_sub_ps(reg, b.reg));
  }

  FP32Vec8 operator/(const FP32Vec8 &b) const {
    return FP32Vec8(_mm256_div_ps(reg, b.reg));
  }

  void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};

struct FP32Vec16 : public Vec<FP32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  union AliasReg {
    __m512 reg;
    float values[VEC_ELEM_NUM];
  };

  __m512 reg;

  explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}

  explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}

  explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}

  explicit FP32Vec16(__m512 data) : reg(data) {}

  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}

  explicit FP32Vec16(const FP32Vec4 &data)
      : reg((__m512)_mm512_inserti32x4(
            _mm512_inserti32x4(
                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
                                   (__m128i)data.reg, 1),
                (__m128i)data.reg, 2),
            (__m128i)data.reg, 3)) {}

  explicit FP32Vec16(const FP32Vec8 &data)
      : reg((__m512)_mm512_inserti32x8(
            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}

  explicit FP32Vec16(const BF16Vec16 &v)
      : reg(_mm512_castsi512_ps(
            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}

  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}

  FP32Vec16 operator*(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_mul_ps(reg, b.reg));
  }

  FP32Vec16 operator+(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_add_ps(reg, b.reg));
  }

  FP32Vec16 operator-(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_sub_ps(reg, b.reg));
  }

  FP32Vec16 operator/(const FP32Vec16 &b) const {
    return FP32Vec16(_mm512_div_ps(reg, b.reg));
  }

  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }

  template <int group_size> float reduce_sub_sum(int idx) {
    static_assert(VEC_ELEM_NUM % group_size == 0);
    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
    __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
    return _mm512_mask_reduce_add_ps(mask, reg);
  }

  void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};

template <typename T> struct VecType { using vec_type = void; };

template <typename T> using vec_t = typename VecType<T>::vec_type;

template <> struct VecType<float> { using vec_type = FP32Vec8; };

#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif

template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };

template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }

#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
  *reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif

inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
  acc = acc + a * b;
}

#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
  *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
    : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}

inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
  acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
      reinterpret_cast<c10::BFloat16 *>(&v);
  *ptr = *(v_ptr + 1);
}

inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
    : reg(_mm256_cvtepi32_epi16(
          _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}

inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
    : reg(_mm512_cvtepi32_epi16(
          _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#endif

inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }

}; // namespace vec_op

#endif
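
A minimal usage sketch (not part of the diff) for the vec_op wrappers above: an 8-wide FP32 dot product using FP32Vec8 loads, the overloaded arithmetic operators, and reduce_sum. It assumes cpu_types.hpp is on the include path and that len is a multiple of FP32Vec8::VEC_ELEM_NUM; this mirrors the accumulate-then-reduce pattern used by the CPU layernorm kernel below.

#include "cpu_types.hpp"

// Dot product of two float buffers; len must be a multiple of 8 here.
float dot_fp32(const float* a, const float* b, int len) {
  vec_op::FP32Vec8 acc(0.0f);
  for (int i = 0; i < len; i += vec_op::FP32Vec8::VEC_ELEM_NUM) {
    vec_op::FP32Vec8 va(a + i);
    vec_op::FP32Vec8 vb(b + i);
    acc = acc + va * vb;    // element-wise multiply-accumulate in 8 lanes
  }
  return acc.reduce_sum();  // horizontal sum via the AliasReg/unroll_loop path
}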
							
								
								
									
csrc/cpu/layernorm.cpp (new file, 117 lines)
@@ -0,0 +1,117 @@
#include "cpu_types.hpp"

namespace {
template <typename scalar_t>
void rms_norm_impl(scalar_t* __restrict__ out,
                   const scalar_t* __restrict__ input,
                   const scalar_t* __restrict__ weight, const float epsilon,
                   const int num_tokens, const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto output_p = out + i * hidden_size;
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      variance = variance + fp32_x * fp32_x;
    }

    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t w(weight + j);

      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_w(w);

      vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;

      scalar_vec_t out(fp32_out);
      out.save(output_p + j);
    }
  }
}

template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
                             scalar_t* __restrict__ residual,
                             const scalar_t* __restrict__ weight,
                             const float epsilon, const int num_tokens,
                             const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto residual_p = residual + i * hidden_size;
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t res(residual_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_res(res);

      fp32_x = fp32_x + fp32_res;
      variance = variance + fp32_x * fp32_x;
      scalar_vec_t out(fp32_x);
      out.save(residual_p + j);
    }

    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t w(weight + j);
      scalar_vec_t res(residual_p + j);

      vec_op::FP32Vec8 fp32_w(w);
      vec_op::FP32Vec8 fp32_res(res);

      vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;

      scalar_vec_t out(fp32_out);
      out.save(input_p + j);
    }
  }
}
}  // namespace

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
    CPU_KERNEL_GUARD_IN(rms_norm_impl)
    rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), epsilon, num_tokens,
                  hidden_size);
    CPU_KERNEL_GUARD_OUT(rms_norm_impl)
  });
}

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "fused_add_rms_norm_impl", [&] {
        CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
        fused_add_rms_norm_impl(
            input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
        CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
      });
}
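
For reference, a scalar sketch (ours, not part of the diff) of what rms_norm_impl computes per token with 8-wide vectors above: out[j] = input[j] * weight[j] / sqrt(mean_j(input[j]^2) + epsilon).

#include <cmath>

// Single-token scalar equivalent of the vectorized RMS norm loop above.
void rms_norm_reference(float* out, const float* input, const float* weight,
                        float epsilon, int hidden_size) {
  float sum_sq = 0.0f;
  for (int j = 0; j < hidden_size; ++j) sum_sq += input[j] * input[j];
  const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);
  for (int j = 0; j < hidden_size; ++j) out[j] = input[j] * inv_rms * weight[j];
}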
							
								
								
									
csrc/cpu/pos_encoding.cpp (new file, 199 lines)
@@ -0,0 +1,199 @@

#include "cpu_types.hpp"

namespace {
template <typename scalar_t>
void rotary_embedding_impl(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           /// [batch_size, seq_len, num_heads,
                                   /// head_size] or [num_tokens, num_heads,
                                   /// head_size]
    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();

  const int embed_dim = rot_dim / 2;
  bool flag = (embed_dim % VEC_ELEM_NUM == 0);
  const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;

  auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
                          scalar_t* qk) {
    int j = 0;
    for (; j < loop_upper; j += VEC_ELEM_NUM) {
      const int rot_offset = j;
      const int x_index = rot_offset;
      const int y_index = embed_dim + rot_offset;

      const int64_t out_x = token_head + x_index;
      const int64_t out_y = token_head + y_index;

      const scalar_vec_t cos(cache_ptr + x_index);
      const scalar_vec_t sin(cache_ptr + y_index);

      const scalar_vec_t q_x(qk + out_x);
      const scalar_vec_t q_y(qk + out_y);

      vec_op::FP32Vec8 fp32_cos(cos);
      vec_op::FP32Vec8 fp32_sin(sin);

      vec_op::FP32Vec8 fp32_q_x(q_x);
      vec_op::FP32Vec8 fp32_q_y(q_y);

      auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
      scalar_vec_t(out1).save(qk + out_x);

      auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
      scalar_vec_t(out2).save(qk + out_y);
    }
    if (!flag) {
      for (; j < embed_dim; ++j) {
        const int x_index = j;
        const int y_index = embed_dim + j;

        const int64_t out_x = token_head + x_index;
        const int64_t out_y = token_head + y_index;

        const float fp32_cos = cache_ptr[x_index];
        const float fp32_sin = cache_ptr[y_index];

        const float fp32_q_x = qk[out_x];
        const float fp32_q_y = qk[out_y];

        qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
        qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
      }
    }
  };

#pragma omp parallel for
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

    for (int i = 0; i < num_heads; ++i) {
      const int head_idx = i;
      const int64_t token_head =
          token_idx * query_stride + head_idx * head_size;
      compute_loop(token_head, cache_ptr, query);
    }

    for (int i = 0; i < num_kv_heads; ++i) {
      const int head_idx = i;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
      compute_loop(token_head, cache_ptr, key);
    }
  }
}

template <typename scalar_t>
void rotary_embedding_gptj_impl(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           /// [batch_size, seq_len, num_heads,
                                   /// head_size] or [num_tokens, num_heads,
                                   /// head_size]
    scalar_t* __restrict__ key,  // [batch_size, seq_len, num_kv_heads,
                                 // head_size] or [num_tokens, num_kv_heads,
                                 // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size,
    const int num_tokens) {
  const int embed_dim = rot_dim / 2;

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int i = 0; i < num_heads; ++i) {
      int64_t pos = positions[token_idx];
      const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
      const scalar_t* cos_cache_ptr = cache_ptr;
      const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      const int head_idx = i;
      const int64_t token_head =
          token_idx * query_stride + head_idx * head_size;
      scalar_t* head_query = token_head + query;
      for (int j = 0; j < embed_dim; j += 1) {
        const int rot_offset = j;
        const int x_index = 2 * rot_offset;
        const int y_index = 2 * rot_offset + 1;

        const float cos = cos_cache_ptr[rot_offset];
        const float sin = sin_cache_ptr[rot_offset];

        const float x = head_query[x_index];
        const float y = head_query[y_index];

        head_query[x_index] = x * cos - y * sin;
        head_query[y_index] = y * cos + x * sin;
      }
    }
  }

#pragma omp parallel for collapse(2)
  for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
    for (int i = 0; i < num_kv_heads; ++i) {
      int64_t pos = positions[token_idx];
      const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
      const scalar_t* cos_cache_ptr = cache_ptr;
      const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      const int head_idx = i;
      const int64_t token_head = token_idx * key_stride + head_idx * head_size;
      scalar_t* head_key = key + token_head;
      for (int j = 0; j < embed_dim; j += 1) {
        const int rot_offset = j;
        const int x_index = 2 * rot_offset;
        const int y_index = 2 * rot_offset + 1;

        const float cos = cos_cache_ptr[rot_offset];
        const float sin = sin_cache_ptr[rot_offset];

        const float x = head_key[x_index];
        const float y = head_key[y_index];

        head_key[x_index] = x * cos - y * sin;
        head_key[y_index] = y * cos + x * sin;
      }
    }
  }
}
};  // namespace

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int64_t head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox) {
  int num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t key_stride = key.stride(-2);
  int64_t query_stride = query.stride(-2);

  VLLM_DISPATCH_FLOATING_TYPES(
      query.scalar_type(), "rotary_embedding_impl", [&] {
        CPU_KERNEL_GUARD_IN(rotary_embedding_impl)
        if (is_neox) {
          rotary_embedding_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size, num_tokens);
        } else {
          rotary_embedding_gptj_impl(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size, num_tokens);
        }

        CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
      });
}
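
The two kernels above apply the same 2-D rotation; they differ only in which element pairs share a (cos, sin) entry. A scalar sketch of that indexing difference (our illustration, not part of the diff): the NeoX path pairs element j with element j + rot_dim/2 (split halves), while the GPT-J path pairs adjacent elements 2j and 2j+1.

// One 2-D rotation, matching the arithmetic used in both compute loops above.
inline void rotate_pair(float& a, float& b, float cos, float sin) {
  const float a0 = a, b0 = b;
  a = a0 * cos - b0 * sin;
  b = b0 * cos + a0 * sin;
}

// head points at one head's head_size values; cos_ptr/sin_ptr each hold
// rot_dim/2 entries for the token's position.
void rotary_one_head(float* head, const float* cos_ptr, const float* sin_ptr,
                     int rot_dim, bool is_neox) {
  const int embed_dim = rot_dim / 2;
  for (int j = 0; j < embed_dim; ++j) {
    if (is_neox) {
      // GPT-NeoX style: first half paired with second half.
      rotate_pair(head[j], head[embed_dim + j], cos_ptr[j], sin_ptr[j]);
    } else {
      // GPT-J style: adjacent even/odd elements form a pair.
      rotate_pair(head[2 * j], head[2 * j + 1], cos_ptr[j], sin_ptr[j]);
    }
  }
}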
							
								
								
									
csrc/cpu/torch_bindings.cpp (new file, 106 lines)
@@ -0,0 +1,106 @@
#include "cache.h"
#include "ops.h"
#include "registration.h"

#include <torch/library.h>

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // vLLM custom ops

  // Attention ops
  // Compute the attention between an input query and the cached keys/values
  // using PagedAttention.
  ops.def(
      "paged_attention_v1("
      "    Tensor! out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
      "    int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);

  // PagedAttention V2.
  ops.def(
      "paged_attention_v2("
      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
      "    Tensor value_cache, int num_kv_heads, float scale,"
      "    Tensor block_tables, Tensor seq_lens, int block_size,"
      "    int max_seq_len, Tensor? alibi_slopes,"
      "    str kv_cache_dtype, float kv_scale, int tp_rank,"
      "    int blocksparse_local_blocks,"
      "    int blocksparse_vert_stride, int blocksparse_block_size,"
      "    int blocksparse_head_sliding_step) -> ()");
  ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);

  // Activation ops

  // Activation function used in SwiGLU.
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul);

  // Activation function used in GeGLU with `none` approximation.
  ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul);

  // Activation function used in GeGLU with `tanh` approximation.
  ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul);

  // GELU implementation used in GPT-2.
  ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_new", torch::kCPU, &gelu_new);

  // Approximate GELU implementation.
  ops.def("gelu_fast(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_fast", torch::kCPU, &gelu_fast);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
      "()");
  ops.impl("rms_norm", torch::kCPU, &rms_norm);

  // In-place fused Add and RMS Normalization.
  ops.def(
      "fused_add_rms_norm(Tensor! input, Tensor! residual, Tensor weight, "
      "float epsilon) -> ()");
  ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm);

  // Rotary embedding
  // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
  ops.def(
      "rotary_embedding(Tensor positions, Tensor! query,"
      "                 Tensor! key, int head_size,"
      "                 Tensor cos_sin_cache, bool is_neox) -> ()");
  ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
  // Cache ops
  // Swap in (out) the cache blocks from src to dst.
  cache_ops.def(
      "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks);

  // Copy the cache blocks from src to dst.
  cache_ops.def(
      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
      "block_mapping) -> ()");
  cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);

  // Reshape the key and value tensors and cache them.
  cache_ops.def(
      "reshape_and_cache(Tensor key, Tensor value,"
      "                  Tensor! key_cache, Tensor! value_cache,"
      "                  Tensor slot_mapping,"
      "                  str kv_cache_dtype,"
      "                  float kv_scale) -> ()");
  cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
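
A standalone sketch of the same def/impl pattern (ours, not part of the diff), using PyTorch's stock TORCH_LIBRARY macro instead of the TORCH_LIBRARY_EXPAND/REGISTER_EXTENSION wrappers from registration.h, which presumably substitute TORCH_EXTENSION_NAME for the literal namespace. The op name, namespace, and function below are hypothetical and only illustrate how a schema string is paired with a CPU implementation.

#include <torch/library.h>
#include <torch/torch.h>

// Toy op: out <- input * factor (illustration only).
void toy_scale(torch::Tensor& out, torch::Tensor& input, double factor) {
  out.copy_(input * factor);
}

TORCH_LIBRARY(toy_cpu_ops, ops) {
  // Schema string uses the same annotations as above: Tensor! marks the
  // mutated output, float maps to double on the C++ side.
  ops.def("toy_scale(Tensor! out, Tensor input, float factor) -> ()");
  ops.impl("toy_scale", torch::kCPU, &toy_scale);
}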
@@ -1,7 +1,7 @@
#pragma once

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
  #include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
@@ -17,9 +17,14 @@
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)
#endif

#ifndef USE_ROCM
@@ -28,6 +33,13 @@
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
    __shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
@@ -35,4 +47,3 @@
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

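
An illustrative device-side sketch (ours, not part of the diff) of how the newly added VLLM_SHFL_XOR_SYNC_WIDTH macro can drive a butterfly reduction over a sub-warp group. It assumes the compatibility header defining the macro is included and the code is compiled as CUDA/HIP device code.

// Butterfly sum over a WIDTH-lane group; after the loop every lane in the
// group holds the group sum.
template <int WIDTH>
__device__ __forceinline__ float subwarp_sum(float val) {
#pragma unroll
  for (int mask = WIDTH / 2; mask > 0; mask >>= 1) {
    val += VLLM_SHFL_XOR_SYNC_WIDTH(val, mask, WIDTH);
  }
  return val;
}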
@@ -1,10 +1,5 @@
#pragma once

#include <torch/extension.h>
int64_t get_device_attribute(int64_t attribute, int64_t device_id);

int get_device_attribute(
    int attribute,
    int device_id);

int get_max_shared_memory_per_block_device_attribute(
    int device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);

@@ -2,34 +2,28 @@
  #include <hip/hip_runtime.h>
  #include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
    int attribute,
    int device_id)
{
    int device, value;
    if (device_id < 0) {
        cudaGetDevice(&device);
    }
    else {
        device = device_id;
    }
    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
    return value;
int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
  int device, value;
  if (device_id < 0) {
    cudaGetDevice(&device);
  } else {
    device = device_id;
  }
  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
                         device);
  return value;
}


int get_max_shared_memory_per_block_device_attribute(
    int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) {
  int64_t attribute;
  // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
  // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

#ifdef USE_ROCM
    attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
  attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
    attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
  attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif

    return get_device_attribute(attribute, device_id);
  return get_device_attribute(attribute, device_id);
}

@@ -1,17 +1,17 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <torch/all.h>

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));
// fake pointer type, must match fptr_t type in ops.h
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
      reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
bool _is_weak_contiguous(torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
  return false;
}

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
      fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
                           reinterpret_cast<float*>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
                          reinterpret_cast<half*>(out.data_ptr()), out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
          stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
      break;
    }
#endif
@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
  }
}

void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  _all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

@@ -122,27 +121,33 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Signal); }
int64_t meta_size() { return sizeof(vllm::Signal); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta();
  auto options =
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
  auto handles =
      torch::empty({static_cast<int64_t>(handle_bytes.size())}, options);
  std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size());
  return {handles, std::move(offsets)};
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
  fa->register_graph_buffers(handles, offsets);
}

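
The get_graph_buffer_ipc_meta change above packs the raw IPC handle bytes into a CPU uint8 tensor so the return type fits the op schema. A hypothetical consumer-side helper (ours, not part of the diff) that would split such a flat tensor back into per-handle strings, given the size of one handle in bytes:

#include <torch/all.h>

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Splits a flat uint8 handle tensor into fixed-size handle strings.
std::vector<std::string> unpack_ipc_handles(const torch::Tensor& handles,
                                            int64_t handle_size) {
  TORCH_CHECK(handles.scalar_type() == torch::kUInt8 &&
              handles.is_contiguous());
  const uint8_t* data = handles.data_ptr<uint8_t>();
  std::vector<std::string> out;
  for (int64_t off = 0; off + handle_size <= handles.numel();
       off += handle_size) {
    out.emplace_back(reinterpret_cast<const char*>(data + off), handle_size);
  }
  return out;
}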
@ -31,9 +31,9 @@ struct Signal {
 | 
			
		||||
  alignas(128) uint32_t end[kMaxBlocks][8];
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
 | 
			
		||||
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
 | 
			
		||||
 | 
			
		||||
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
 | 
			
		||||
struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 | 
			
		||||
 | 
			
		||||
// like std::array, but aligned
 | 
			
		||||
template <typename T, int sz>
 | 
			
		||||
@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
 | 
			
		||||
// scalar add functions
 | 
			
		||||
// for some reason when compiling with Pytorch, the + operator for half and
 | 
			
		||||
// bfloat is disabled so we call the intrinsics directly
 | 
			
		||||
DINLINE half &assign_add(half &a, half b) {
 | 
			
		||||
DINLINE half& assign_add(half& a, half b) {
 | 
			
		||||
  a = __hadd(a, b);
 | 
			
		||||
  return a;
 | 
			
		||||
}
 | 
			
		||||
DINLINE float &assign_add(float &a, float b) { return a += b; }
 | 
			
		||||
DINLINE float& assign_add(float& a, float b) { return a += b; }
 | 
			
		||||
 | 
			
		||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
 | 
			
		||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
 | 
			
		||||
@ -80,14 +80,14 @@ template <>
 | 
			
		||||
DINLINE nv_bfloat16 downcast_s(float val) {
 | 
			
		||||
  return __float2bfloat16(val);
 | 
			
		||||
}
 | 
			
		||||
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
 | 
			
		||||
DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
 | 
			
		||||
  a = __hadd(a, b);
 | 
			
		||||
  return a;
 | 
			
		||||
}
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
template <typename T, int N>
 | 
			
		||||
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
 | 
			
		||||
DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
 | 
			
		||||
#pragma unroll
 | 
			
		||||
  for (int i = 0; i < N; i++) {
 | 
			
		||||
    assign_add(a.data[i], b.data[i]);
 | 
			
		||||
@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
 | 
			
		||||
// prior memory accesses. Note: volatile writes will not be reordered against
 | 
			
		||||
// other volatile writes.
 | 
			
		||||
template <int ngpus>
 | 
			
		||||
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
 | 
			
		||||
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
 | 
			
		||||
                        int rank) {
 | 
			
		||||
  if (threadIdx.x < ngpus) {
 | 
			
		||||
    // reset flag for next time
 | 
			
		||||
@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
 | 
			
		||||
    // Latency = 1 p2p write
 | 
			
		||||
    sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
 | 
			
		||||
    // wait until we got true from all ranks
 | 
			
		||||
    while (!self_sg->start[blockIdx.x][threadIdx.x])
 | 
			
		||||
      ;
 | 
			
		||||
    while (!self_sg->start[blockIdx.x][threadIdx.x]);
 | 
			
		||||
  }
 | 
			
		||||
  __syncthreads();
 | 
			
		||||
}
 | 
			
		||||
@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
 | 
			
		||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
 | 
			
		||||
// we don't need to make any visibility guarantees for prior memory accesses.
 | 
			
		||||
template <int ngpus, bool final_sync = false>
 | 
			
		||||
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
 | 
			
		||||
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
 | 
			
		||||
                      int rank) {
 | 
			
		||||
  __syncthreads();
 | 
			
		||||
  // eliminate the case that prior writes are not visible after signals become
 | 
			
		||||
  // visible. Note that I did not managed to make this happen through a lot of
 | 
			
		||||
  // testing. Might be the case that hardware provides stronger guarantee than
 | 
			
		||||
  // the memory model. 
 | 
			
		||||
  // the memory model.
 | 
			
		||||
  if constexpr (!final_sync) __threadfence_system();
 | 
			
		||||
  if (threadIdx.x < ngpus) {
 | 
			
		||||
    // reset flag for next time
 | 
			
		||||
@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
 | 
			
		||||
    // Latency = 1 p2p write
 | 
			
		||||
    sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
 | 
			
		||||
    // wait until we got true from all ranks
 | 
			
		||||
    while (!self_sg->end[blockIdx.x][threadIdx.x])
 | 
			
		||||
      ;
 | 
			
		||||
    while (!self_sg->end[blockIdx.x][threadIdx.x]);
 | 
			
		||||
  }
 | 
			
		||||
  if constexpr (!final_sync) __syncthreads();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P* ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
                               volatile Signal* self_sg, T* __restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, self_sg, rank);
}

template <typename P>
DINLINE P* get_tmp_buf(volatile Signal* sg) {
  return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
                               volatile Signal* self_sg, T* __restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  int largest_part = part + size % ngpus;
  const P* ptrs[ngpus];
  P* tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P*)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
      int gather_from_rank = ((rank + i) % ngpus);
      if (gather_from_rank == ngpus - 1 || idx < part) {
        int dst_idx = gather_from_rank * part + idx;
        ((P*)result)[dst_idx] = tmps[i][idx];
      }
    }
  }
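To make the partition math in cross_device_reduce_2stage easier to follow, here is a minimal host-side reference of the reduce-scatter + all-gather split, for illustration only: each rank owns `part = size / ngpus` packed elements, the last rank also takes the remainder, and the gathered result is reassembled from every owner's slice. The helper name cpu_two_stage_allreduce is hypothetical.

// Hypothetical CPU reference of the two-stage split used above (illustration only).
#include <vector>

void cpu_two_stage_allreduce(const std::vector<std::vector<float>>& inputs,
                             std::vector<float>& result) {
  int ngpus = (int)inputs.size();
  int size = (int)inputs[0].size();
  int part = size / ngpus;
  result.assign(size, 0.0f);
  for (int rank = 0; rank < ngpus; ++rank) {
    int start = rank * part;
    int end = (rank == ngpus - 1) ? size : start + part;  // last rank takes the remainder
    // stage 1 (reduce-scatter): the owning rank reduces its slice across all ranks
    for (int idx = start; idx < end; ++idx) {
      float acc = 0.0f;
      for (int r = 0; r < ngpus; ++r) acc += inputs[r][idx];
      // stage 2 (all-gather): every rank copies this slice back from its owner
      result[idx] = acc;
    }
  }
}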
@@ -261,14 +258,14 @@ class CustomAllreduce {

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void*, RankData*> buffers_;
  Signal* self_sg_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void*> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char*> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
@@ -279,22 +276,22 @@ class CustomAllreduce {
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t* handles,
                  const std::vector<int64_t>& offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        self_sg_(meta),
        d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Signal* rank_sg;
      if (i != rank_) {
        char* handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_sg = (Signal*)handle;
      } else {
        rank_sg = self_sg_;
      }
@@ -302,13 +299,13 @@ class CustomAllreduce {
    }
  }

  char* open_ipc_handle(const void* ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
    if (new_handle) {
      char* ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t*)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
@@ -323,7 +320,7 @@ class CustomAllreduce {
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void* base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
@@ -331,8 +328,8 @@ class CustomAllreduce {
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char*)ptr) - ((char*)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }
@@ -344,13 +341,13 @@ class CustomAllreduce {
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string>& handles,
                       const std::vector<int64_t>& offsets, void* self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char* handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
@@ -371,17 +368,17 @@ class CustomAllreduce {
  // got a different address. IPC handles have internal reference counting
  // mechanism so overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string>& handles,
      const std::vector<std::vector<int64_t>>& offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto& rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char* handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
@@ -405,7 +402,7 @@ class CustomAllreduce {
   * will cause contention on NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
@@ -418,7 +415,7 @@ class CustomAllreduce {
                               std::to_string(kMaxBlocks) + ". Got " +
                               std::to_string(block_limit));

    RankData* ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
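The registration paths above all reduce to the same CUDA IPC pattern: export a handle for the allocation base on one rank, open it on a peer, then re-apply the stored offset to recover the original pointer. A minimal sketch of that pattern, assuming the CUDACHECK macro used throughout these files; export_buffer and import_buffer are hypothetical helper names.

// Minimal sketch of the IPC export/import pattern CustomAllreduce relies on.
#include <cuda_runtime.h>

cudaIpcMemHandle_t export_buffer(void* dev_ptr) {
  cudaIpcMemHandle_t handle;
  CUDACHECK(cudaIpcGetMemHandle(&handle, dev_ptr));  // handle refers to the allocation base
  return handle;  // exchanged with peer ranks, e.g. via MPI or torch.distributed
}

void* import_buffer(cudaIpcMemHandle_t handle, int64_t offset) {
  void* base = nullptr;
  CUDACHECK(cudaIpcOpenMemHandle(&base, handle, cudaIpcMemLazyEnablePeerAccess));
  // the exported pointer may not be the allocation base, so re-apply its offset
  return static_cast<char*>(base) + offset;
}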
@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
}

template <typename T>
__global__ void set_data(T* data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
@@ -56,8 +56,8 @@ __global__ void set_data(T* data, int size, int myRank) {
}

template <typename T>
__global__ void convert_data(const T* data1, const T* data2, double* fdata1,
                             double* fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
@@ -65,7 +65,7 @@ __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
  }
}

__global__ void init_rand(curandState_t* state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t* state, int size, int nRanks) {
}

template <typename T>
__global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
}

template <typename T>
void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
         int data_size, bool performance_test) {
  T* result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  vllm::Signal* buffer;
  T* self_data_copy;
  /**
   * Allocate IPC buffer
   *
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void* rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  auto* self_data =
      reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
                           sizeof(vllm::Signal) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char* begin = (char*)&data_handles[i];
      char* end = (char*)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(nRanks,
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
    fa.register_buffer(handles, offsets, self_data);
  }

  double* ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t* states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
  CUDACHECK(cudaStreamDestroy(stream));
}

int main(int argc, char** argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
@@ -296,7 +296,7 @@ int main(int argc, char** argv) {
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
@@ -4,34 +4,32 @@
 */
#pragma once

#include <torch/all.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)   \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                               \
                     VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
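The macros above simply forward to PyTorch's AT_DISPATCH_SWITCH / AT_DISPATCH_CASE machinery, which instantiates the lambda once per listed dtype with scalar_t bound to the matching C++ type. A hypothetical usage sketch, assuming it lives in a .cu translation unit; scale_inplace and my_kernel are made-up names.

// Hypothetical usage sketch of VLLM_DISPATCH_FLOATING_TYPES.
template <typename scalar_t>
__global__ void my_kernel(scalar_t* data, int64_t n, float factor) {
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    data[i] = (scalar_t)((float)data[i] * factor);  // works for float/half/bf16
  }
}

void scale_inplace(torch::Tensor& t, float factor) {
  VLLM_DISPATCH_FLOATING_TYPES(t.scalar_type(), "scale_inplace", [&] {
    // inside the lambda, the AT_DISPATCH machinery defines scalar_t
    my_kernel<scalar_t><<<32, 256>>>(t.data_ptr<scalar_t>(), t.numel(), factor);
  });
}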
@@ -1,26 +1,34 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
@@ -30,48 +38,260 @@ __global__ void rms_norm_kernel(
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}
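Per token, the kernel above computes y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i]. A minimal host-side reference for sanity-checking that math; rms_norm_ref is a hypothetical name and operates on a single token row.

// Hypothetical single-token host reference of what rms_norm_kernel computes.
#include <cmath>
#include <vector>

std::vector<float> rms_norm_ref(const std::vector<float>& x,
                                const std::vector<float>& w, float eps) {
  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms = 1.0f / std::sqrt(sum_sq / x.size() + eps);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = x[i] * inv_rms * w[i];
  return y;
}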
/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
 */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif    // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
          // 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops.
 */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* Not theoretically necessary that width is a power of 2 but should
     almost always be the case for optimization purposes */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp += T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp *= T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
        temp_f.x *= scale;
        temp_f.y *= scale;
        T2 temp = Converter::convert(temp_f);
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float temp = Converter::convert(data[i]) * scale;
        data[i] = Converter::convert(temp);
      }
    }
    return *this;
  }

  __device__ float sum_squares() const {
    float result = 0.0f;
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 z = Converter::convert(T2{data[i], data[i + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float x = Converter::convert(data[i]);
        result += x * x;
      }
    }
    return result;
  }
};
/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = input_v[id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[id] = temp;
  }
}
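In math form, the specialized kernel above and the generic fallback that follows compute, for each token with input row x, residual row r, weight w, and H = hidden_size:

\begin{aligned}
r' &= x + r, \\
s &= \Big(\tfrac{1}{H}\sum_{i=1}^{H} (r'_i)^2 + \varepsilon\Big)^{-1/2}, \\
y_i &= r'_i \, s \, w_i,
\end{aligned}

with the residual tensor overwritten by r' and the input tensor overwritten by y.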
/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[blockIdx.x * hidden_size + idx];
    z += residual[blockIdx.x * hidden_size + idx];
    float x = (float)z;
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = z;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace vllm
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

@@ -79,42 +299,54 @@ void rms_norm(
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
  });
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  VLLM_DISPATCH_FLOATING_TYPES(                                                \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /*If the tensor types are FP16/BF16, try to use the optimized kernel
    with packed + vectorized ops.
    Max optimization is achieved with a width-8 vector of FP16/BF16s
    since we can load at most 128 bits at once in a global memory op.
    However, this requires each tensor's data to be aligned to 16
    bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}
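A hypothetical host-side usage sketch of the launchers above, assuming their declarations are visible to the caller (e.g. through the project's ops header); the function name example and the shapes are made up.

// Hypothetical caller of fused_add_rms_norm (FP16 tensors, 16-byte aligned,
// hidden_size divisible by 8, so the vectorized width-8 kernel is selected).
#include <torch/all.h>

void example(int num_tokens, int hidden_size) {
  auto opts = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({num_tokens, hidden_size}, opts);
  torch::Tensor residual = torch::randn({num_tokens, hidden_size}, opts);
  torch::Tensor weight = torch::ones({hidden_size}, opts);
  // residual <- input + residual, input <- RMSNorm(residual) * weight
  fused_add_rms_norm(input, residual, weight, 1e-5);
}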
@@ -1,7 +0,0 @@
#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}
@@ -1,9 +1,7 @@
#pragma once

#include <torch/all.h>

void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
                  torch::Tensor& token_expert_indices,
                  torch::Tensor& gating_output);
@@ -16,18 +16,25 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"

#ifndef USE_ROCM
    #include <cub/util_type.cuh>
    #include <cub/cub.cuh>
#else
    #include <hipcub/util_type.hpp>
    #include <hipcub/hipcub.hpp>
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))

namespace vllm {
namespace moe {

static constexpr int WARP_SIZE = 32;

/// Aligned array type
template <
    typename T,
@@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
    }

    // From this point, thread max in all the threads have the max within the row.
@@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
    }

    // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll
        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
        {
            float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
            int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);

            // We want lower indices to "win" in every thread so we break ties this way
            if (other_max > max_val || (other_max == max_val && other_expert < expert))
@@ -383,7 +390,7 @@ struct TopkConstants
{
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
    static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

    static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
    static constexpr int VPT = Constants::VPT;
    static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
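VLLM_SHFL_XOR_SYNC_WIDTH wraps the platform shuffle intrinsic so the same butterfly reductions build on both CUDA and ROCm. For reference, here is a standalone sketch of the same max-reduction pattern written directly against the CUDA intrinsic __shfl_xor_sync; the kernel name warp_row_max is hypothetical and it assumes one 32-thread warp per row.

// Standalone sketch of the butterfly max-reduction pattern used above.
__global__ void warp_row_max(const float* in, float* out) {
  // blockDim.x == 32: one warp per row, each lane holds one element
  float thread_max = in[blockIdx.x * 32 + threadIdx.x];
#pragma unroll
  for (int mask = 32 / 2; mask > 0; mask /= 2) {
    thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, 32));
  }
  // after log2(32) exchange steps every lane holds the row maximum
  if (threadIdx.x == 0) out[blockIdx.x] = thread_max;
}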
							
								
								
									
csrc/moe/torch_bindings.cpp (new file, 12 lines)
@@ -0,0 +1,12 @@
#include "registration.h"
#include "moe_ops.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  // Apply topk softmax to the gating outputs.
  m.def(
      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
      "token_expert_indices, Tensor gating_output) -> ()");
  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
@@ -1,4 +1,4 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
@@ -7,119 +7,128 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#define CEILDIV(x, y) (((x) + (y) - 1) / (y))

namespace vllm {

namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
}  // namespace

template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
                                            int32_t* sorted_token_ids,
                                            int32_t* expert_ids,
                                            int32_t* total_tokens_post_pad,
                                            int32_t num_experts,
                                            int32_t block_size, size_t numel) {
  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  extern __shared__ int32_t shared_mem[];

  int32_t* tokens_cnts =
      shared_mem;  // 2d tensor with shape (num_experts + 1, num_experts)
  int32_t* cumsum =
      shared_mem + (num_experts + 1) *
                       num_experts;  // 1d tensor with shape (num_experts + 1)

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  /**
   * In the first step we compute token_cnts[thread_index + 1][expert_index],
   * which counts how many tokens in the token shard of thread_index are
   * assigned to expert expert_index.
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
  for (int i = 1; i <= blockDim.x; ++i) {
    tokens_cnts[index(num_experts, i, threadIdx.x)] +=
        tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                          block_size) *
                      block_size;
    }
    *total_tokens_post_pad = cumsum[num_experts];
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
       i += block_size) {
    expert_ids[i / block_size] = threadIdx.x;
  }

  /**
   * Each thread processes a token shard, calculating the index of each token
   * after sorting by expert number. Given the example topk_ids =
   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
   * padding value(preset in python).
   */
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
 | 
			
		||||
    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
 | 
			
		||||
void moe_align_block_size(
 | 
			
		||||
    torch::Tensor topk_ids,
 | 
			
		||||
    int num_experts,
 | 
			
		||||
    int block_size,
 | 
			
		||||
    torch::Tensor sorted_token_ids,
 | 
			
		||||
    torch::Tensor experts_ids,
 | 
			
		||||
    torch::Tensor num_tokens_post_pad) {
 | 
			
		||||
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
    VLLM_DISPATCH_INTEGRAL_TYPES(
 | 
			
		||||
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
 | 
			
		||||
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
 | 
			
		||||
        const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
 | 
			
		||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
 | 
			
		||||
                          int64_t block_size, torch::Tensor sorted_token_ids,
 | 
			
		||||
                          torch::Tensor experts_ids,
 | 
			
		||||
                          torch::Tensor num_tokens_post_pad) {
 | 
			
		||||
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
  VLLM_DISPATCH_INTEGRAL_TYPES(
 | 
			
		||||
      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
 | 
			
		||||
        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
 | 
			
		||||
        // tensors
 | 
			
		||||
        const int32_t shared_mem =
 | 
			
		||||
            ((num_experts + 1) * num_experts + (num_experts + 1)) *
 | 
			
		||||
            sizeof(int32_t);
 | 
			
		||||
 | 
			
		||||
        // set dynamic shared mem
 | 
			
		||||
        auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
 | 
			
		||||
        AT_CUDA_CHECK(
 | 
			
		||||
            VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
 | 
			
		||||
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
 | 
			
		||||
            (void*)kernel, shared_mem));
 | 
			
		||||
        kernel<<<1, num_experts, shared_mem, stream>>>(
 | 
			
		||||
            topk_ids.data_ptr<scalar_t>(),
 | 
			
		||||
            sorted_token_ids.data_ptr<int32_t>(), 
 | 
			
		||||
            experts_ids.data_ptr<int32_t>(), 
 | 
			
		||||
            num_tokens_post_pad.data_ptr<int32_t>(), 
 | 
			
		||||
            num_experts,
 | 
			
		||||
            block_size,
 | 
			
		||||
            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
 | 
			
		||||
            experts_ids.data_ptr<int32_t>(),
 | 
			
		||||
            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
 | 
			
		||||
            topk_ids.numel());
 | 
			
		||||
    });
 | 
			
		||||
      });
 | 
			
		||||
}
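
For orientation (not part of the diff): a minimal host-side sketch of the layout moe_align_block_size produces for the example in the kernel comment above. The -1 padding sentinel and the helper names are illustrative assumptions; in vLLM the padding value is preset on the Python side.

// moe_align_block_size_demo.cpp -- hypothetical standalone illustration.
#include <cstdint>
#include <cstdio>
#include <vector>

static int32_t ceil_div(int32_t a, int32_t b) { return (a + b - 1) / b; }

int main() {
  const std::vector<int32_t> topk_ids = {0, 1, 2, 1, 2, 3, 0, 3, 4};
  const int32_t num_experts = 5, block_size = 4, pad = -1;

  // Count tokens per expert, then pad every expert's range up to a multiple of
  // block_size (this is what cumsum holds in the kernel).
  std::vector<int32_t> cnt(num_experts, 0);
  for (int32_t e : topk_ids) ++cnt[e];
  std::vector<int32_t> cumsum(num_experts + 1, 0);
  for (int32_t e = 0; e < num_experts; ++e)
    cumsum[e + 1] = cumsum[e] + ceil_div(cnt[e], block_size) * block_size;

  // Scatter each token index into its expert's padded range.
  std::vector<int32_t> sorted_token_ids(cumsum[num_experts], pad);
  std::vector<int32_t> offset(num_experts, 0);
  for (int32_t i = 0; i < (int32_t)topk_ids.size(); ++i) {
    const int32_t e = topk_ids[i];
    sorted_token_ids[cumsum[e] + offset[e]++] = i;
  }

  // Prints: 0 6 -1 -1 1 3 -1 -1 2 4 -1 -1 5 7 -1 -1 8 -1 -1 -1
  for (int32_t v : sorted_token_ids) printf("%d ", v);
  printf("\n");
}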

csrc/ops.h (247 changed lines)
@@ -1,159 +1,146 @@
#pragma once

#include <torch/library.h>

void paged_attention_v1(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void paged_attention_v2(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
    int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
    const int64_t blocksparse_local_blocks,
    const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
    const int64_t blocksparse_head_sliding_step);

void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon);

void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                      torch::Tensor& key, int64_t head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                              torch::Tensor& key, int64_t head_size,
                              torch::Tensor& cos_sin_cache, bool is_neox,
                              int64_t rot_dim,
                              torch::Tensor& cos_sin_cache_offsets);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_new(torch::Tensor& out, torch::Tensor& input);

void gelu_fast(torch::Tensor& out, torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias);

torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes);

torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
                       torch::Tensor _scaling_factors, torch::Tensor _zeros,
                       int64_t split_k_iters);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
                             torch::Tensor _scaling_factors,
                             torch::Tensor _zeros, int64_t split_k_iters,
                             int64_t thx, int64_t thy);

torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                          torch::Tensor& b_scales, torch::Tensor& workspace,
                          int64_t size_m, int64_t size_n, int64_t size_k);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                  torch::Tensor& b_meta,
                                  torch::Tensor& b_scales,
                                  torch::Tensor& workspace, int64_t num_bits,
                                  int64_t size_m, int64_t size_n,
                                  int64_t size_k);

torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& g_idx,
                               torch::Tensor& perm, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,
                               int64_t size_k, bool is_k_full);

torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                 int64_t size_k, int64_t size_n,
                                 int64_t num_bits);

void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
                          torch::Tensor const& b, torch::Tensor const& a_scales,
                          torch::Tensor const& b_scales);
#endif

void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);

void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);

torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                        torch::Tensor b_gptq_qzeros,
                        torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
                        bool use_exllama, int64_t bit);

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                             torch::Tensor& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                              torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                          int64_t block_size, torch::Tensor sorted_token_ids,
                          torch::Tensor experts_ids,
                          torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int64_t rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif

@@ -1,4 +1,4 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

@@ -7,14 +7,10 @@
namespace vllm {

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(
    scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
    const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* cache_ptr, const int head_size, const int num_heads,
    const int num_kv_heads, const int rot_dim, const int token_idx,
    const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(
        key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
                                            // head_size] or [num_tokens,
                                            // num_heads, head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
    const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                            // [num_tokens]
    scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
                                            // head_size] or [num_tokens,
                                            // num_heads, head_size]
    scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len]
                                                        // or [num_tokens]
    const int rot_dim, const int64_t query_stride, const int64_t key_stride,
    const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
  const scalar_t* cache_ptr =
      cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(
      query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
      token_idx, query_stride, key_stride);
}

}  // namespace vllm

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
@@ -132,39 +138,24 @@ void rotary_embedding(
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
          query_stride, key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
              head_size);
    }
  });
}

/*
@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void batched_rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,  // [batch_size, seq_len, num_heads * head_size] or
                           // [num_tokens, num_heads * head_size]
    torch::Tensor& key,    // [batch_size, seq_len, num_kv_heads * head_size] or
                           // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox, int64_t rot_dim,
    torch::Tensor& cos_sin_cache_offsets  // [num_tokens]
) {
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  int num_heads = query.size(-1) / head_size;
@@ -188,39 +180,24 @@ void batched_rotary_embedding(
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      vllm::batched_rotary_embedding_kernel<scalar_t, true>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    } else {
      vllm::batched_rotary_embedding_kernel<scalar_t, false>
          <<<grid, block, 0, stream>>>(
              positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
              key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
              cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
              key_stride, num_heads, num_kv_heads, head_size);
    }
  });
}
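
For orientation (not part of the diff): a hedged CPU sketch of the NeoX-style rotation that apply_token_rotary_embedding performs for one (token, head) pair, assuming the documented cache layout (embed_dim = rot_dim / 2 cos values followed by embed_dim sin values). The function name and the use of float are illustrative.

#include <vector>

// Rotates one head vector in place, pairing element i with element
// i + embed_dim, mirroring the IS_NEOX branch above.
void rotate_neox_head_demo(std::vector<float>& head,
                           const std::vector<float>& cache, int rot_dim) {
  const int embed_dim = rot_dim / 2;
  const float* cos_ptr = cache.data();
  const float* sin_ptr = cache.data() + embed_dim;
  for (int i = 0; i < embed_dim; ++i) {
    const float x = head[i];
    const float y = head[i + embed_dim];
    head[i] = x * cos_ptr[i] - y * sin_ptr[i];
    head[i + embed_dim] = y * cos_ptr[i] + x * sin_ptr[i];
  }
}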

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)

@@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 128) \
    f(in_T, out_T, W_T, narrow, 256) \
    f(in_T, out_T, W_T, narrow, 512) \
    f(in_T, out_T, W_T, narrow, 640) \
    f(in_T, out_T, W_T, narrow, 768) \
    f(in_T, out_T, W_T, narrow, 1024) \
    f(in_T, out_T, W_T, narrow, 1152) \
@@ -27,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 2752) \
    f(in_T, out_T, W_T, narrow, 2816) \
    f(in_T, out_T, W_T, narrow, 3072) \
    f(in_T, out_T, W_T, narrow, 3328) \
    f(in_T, out_T, W_T, narrow, 3456) \
    f(in_T, out_T, W_T, narrow, 3584) \
    f(in_T, out_T, W_T, narrow, 4096) \
@@ -35,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 5504) \
    f(in_T, out_T, W_T, narrow, 5632) \
    f(in_T, out_T, W_T, narrow, 6144) \
    f(in_T, out_T, W_T, narrow, 6400) \
    f(in_T, out_T, W_T, narrow, 6848) \
    f(in_T, out_T, W_T, narrow, 6912) \
    f(in_T, out_T, W_T, narrow, 7168) \
@@ -46,11 +49,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 13696) \
    f(in_T, out_T, W_T, narrow, 13824) \
    f(in_T, out_T, W_T, narrow, 14336) \
    f(in_T, out_T, W_T, narrow, 15360) \
    f(in_T, out_T, W_T, narrow, 16384) \
    f(in_T, out_T, W_T, narrow, 20480) \
    f(in_T, out_T, W_T, narrow, 22016) \
    f(in_T, out_T, W_T, narrow, 24576) \
    f(in_T, out_T, W_T, narrow, 27392) \
    f(in_T, out_T, W_T, narrow, 27648) \
    f(in_T, out_T, W_T, narrow, 28672) \
    f(in_T, out_T, W_T, narrow, 32000) \
    f(in_T, out_T, W_T, narrow, 32256) \
@@ -58,9 +63,91 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 32768) \
    f(in_T, out_T, W_T, narrow, 33024) \
    f(in_T, out_T, W_T, narrow, 36864) \
    f(in_T, out_T, W_T, narrow, 43264) \
    f(in_T, out_T, W_T, narrow, 49152) \
    f(in_T, out_T, W_T, narrow, 64000) \
    f(in_T, out_T, W_T, narrow, 64256) \
    f(in_T, out_T, W_T, narrow, 64512) \
    f(in_T, out_T, W_T, narrow, 102400) \
    f(in_T, out_T, W_T, narrow, 102656) \
    f(in_T, out_T, W_T, narrow, 102912) \
    f(in_T, out_T, W_T, narrow, 128000) \
    f(in_T, out_T, W_T, narrow, 128256) \
    f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py

// Used for defining kernels going from the variety of
// dim in to the narrow dim out
//  Using it for the fully sharded column
//  parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
    f(in_T, out_T, W_T, 128, narrow) \
    f(in_T, out_T, W_T, 256, narrow) \
    f(in_T, out_T, W_T, 512, narrow) \
    f(in_T, out_T, W_T, 640, narrow) \
    f(in_T, out_T, W_T, 768, narrow) \
    f(in_T, out_T, W_T, 1024, narrow) \
    f(in_T, out_T, W_T, 1152, narrow) \
    f(in_T, out_T, W_T, 1280, narrow) \
    f(in_T, out_T, W_T, 1536, narrow) \
    f(in_T, out_T, W_T, 1728, narrow) \
    f(in_T, out_T, W_T, 1792, narrow) \
    f(in_T, out_T, W_T, 2048, narrow) \
    f(in_T, out_T, W_T, 2304, narrow) \
    f(in_T, out_T, W_T, 2560, narrow) \
    f(in_T, out_T, W_T, 2752, narrow) \
    f(in_T, out_T, W_T, 2816, narrow) \
    f(in_T, out_T, W_T, 3072, narrow) \
    f(in_T, out_T, W_T, 3328, narrow) \
    f(in_T, out_T, W_T, 3456, narrow) \
    f(in_T, out_T, W_T, 3584, narrow) \
    f(in_T, out_T, W_T, 4096, narrow) \
    f(in_T, out_T, W_T, 4608, narrow) \
    f(in_T, out_T, W_T, 5120, narrow) \
    f(in_T, out_T, W_T, 5504, narrow) \
    f(in_T, out_T, W_T, 5632, narrow) \
    f(in_T, out_T, W_T, 6144, narrow) \
    f(in_T, out_T, W_T, 6400, narrow) \
    f(in_T, out_T, W_T, 6848, narrow) \
    f(in_T, out_T, W_T, 6912, narrow) \
    f(in_T, out_T, W_T, 7168, narrow) \
    f(in_T, out_T, W_T, 8192, narrow) \
    f(in_T, out_T, W_T, 9216, narrow) \
    f(in_T, out_T, W_T, 10240, narrow) \
    f(in_T, out_T, W_T, 11008, narrow) \
    f(in_T, out_T, W_T, 12288, narrow) \
    f(in_T, out_T, W_T, 13696, narrow) \
    f(in_T, out_T, W_T, 13824, narrow) \
    f(in_T, out_T, W_T, 14336, narrow) \
    f(in_T, out_T, W_T, 15360, narrow) \
    f(in_T, out_T, W_T, 16384, narrow) \
    f(in_T, out_T, W_T, 20480, narrow) \
    f(in_T, out_T, W_T, 22016, narrow) \
    f(in_T, out_T, W_T, 24576, narrow) \
    f(in_T, out_T, W_T, 27392, narrow) \
    f(in_T, out_T, W_T, 27648, narrow) \
    f(in_T, out_T, W_T, 28672, narrow) \
    f(in_T, out_T, W_T, 32000, narrow) \
    f(in_T, out_T, W_T, 32256, narrow) \
    f(in_T, out_T, W_T, 32512, narrow) \
    f(in_T, out_T, W_T, 32768, narrow) \
    f(in_T, out_T, W_T, 33024, narrow) \
    f(in_T, out_T, W_T, 36864, narrow) \
    f(in_T, out_T, W_T, 43264, narrow) \
    f(in_T, out_T, W_T, 49152, narrow) \
    f(in_T, out_T, W_T, 64000, narrow) \
    f(in_T, out_T, W_T, 64256, narrow) \
    f(in_T, out_T, W_T, 64512, narrow) \
    f(in_T, out_T, W_T, 102400, narrow) \
    f(in_T, out_T, W_T, 102656, narrow) \
    f(in_T, out_T, W_T, 102912, narrow) \
    f(in_T, out_T, W_T, 128000, narrow) \
    f(in_T, out_T, W_T, 128256, narrow) \
    f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA

// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8)  \
@@ -68,4 +155,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)

#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
    FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
    FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
    FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
    f(in_T, out_T, W_T, 8, 64) \
    f(in_T, out_T, W_T, 16, 64) \
    f(in_T, out_T, W_T, 32, 64) \
    f(in_T, out_T, W_T, 64, 64)

// clang-format on

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)

@@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
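
For orientation (not part of the diff): the FOR_* lists above are X-macros; every f(...) row expands whatever macro is passed in as f (INST_BGMV_TWOSIDE or INST_BGMV_ONESIDE in the .cu files above), so a single list of feature dimensions drives all kernel instantiations. A minimal sketch with made-up names:

// Stand-in for an instantiation macro: declares one symbol per (feat_in, feat_out).
#define DEMO_INSTANTIATE(in_T, out_T, W_T, feat_in, feat_out) \
  void bgmv_demo_##feat_in##_##feat_out();

// Shortened stand-in for FOR_INST_BGMV_NARROW with narrow fixed to 16.
#define DEMO_FOR_NARROW(f, in_T, out_T, W_T) \
  f(in_T, out_T, W_T, 128, 16)               \
  f(in_T, out_T, W_T, 256, 16)               \
  f(in_T, out_T, W_T, 512, 16)

// Expands to declarations of bgmv_demo_128_16, bgmv_demo_256_16, bgmv_demo_512_16.
DEMO_FOR_NARROW(DEMO_INSTANTIATE, float, float, float)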