Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.5.3.pos...v0.6.0 (522 commits)
(The compare view lists the 522 commits here as a table of abbreviated SHA1s, from 32e7db2536 through 0eb0757bef; the author, date, and message columns were not captured in this mirror, so the per-commit rows are omitted.)
@@ -1,36 +1,43 @@
import os
import sys
import zipfile

-MAX_SIZE_MB = 200
+# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
+VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))


def print_top_10_largest_files(zip_file):
    """Print the top 10 largest files in the given zip file."""
    with zipfile.ZipFile(zip_file, 'r') as z:
        file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
        file_sizes.sort(key=lambda x: x[1], reverse=True)
        for f, size in file_sizes[:10]:
-            print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
+            print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.")


def check_wheel_size(directory):
    """Check the size of .whl files in the given directory."""
    for root, _, files in os.walk(directory):
-        for f in files:
-            if f.endswith(".whl"):
-                wheel_path = os.path.join(root, f)
-                wheel_size = os.path.getsize(wheel_path)
-                wheel_size_mb = wheel_size / (1024 * 1024)
-                if wheel_size_mb > MAX_SIZE_MB:
-                    print(
-                        f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
-                        f"compare to the allowed size ({MAX_SIZE_MB} MB).")
+        for file_name in files:
+            if file_name.endswith(".whl"):
+                wheel_path = os.path.join(root, file_name)
+                wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
+                if wheel_size_mb > VLLM_MAX_SIZE_MB:
+                    print(f"Not allowed: Wheel {wheel_path} is larger "
+                          f"({wheel_size_mb:.2f} MB) than the limit "
+                          f"({VLLM_MAX_SIZE_MB} MB).")
                    print_top_10_largest_files(wheel_path)
                    return 1
                else:
                    print(f"Wheel {wheel_path} is within the allowed size "
-                          f"({wheel_size_mb} MB).")
+                          f"({wheel_size_mb:.2f} MB).")
    return 0


if __name__ == "__main__":
-    import sys
-    sys.exit(check_wheel_size(sys.argv[1]))
+    if len(sys.argv) < 2:
+        print("Usage: python check-wheel-size.py <directory>")
+        sys.exit(1)
+
+    directory = sys.argv[1]
+    sys.exit(check_wheel_size(directory))
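
As a usage note (not part of the diff above): the updated check reads its limit from the `VLLM_MAX_SIZE_MB` environment variable instead of a hard-coded constant. A minimal sketch of driving it from Python, where the script path and the `dist` wheel directory are assumptions for illustration:

```
# Minimal usage sketch (assumed paths): run the updated wheel-size check,
# overriding the 250 MB default via the VLLM_MAX_SIZE_MB environment variable.
import os
import subprocess
import sys

env = dict(os.environ, VLLM_MAX_SIZE_MB="300")  # hypothetical 300 MB limit
result = subprocess.run(
    [sys.executable, ".buildkite/check-wheel-size.py", "dist"],  # assumed script path and wheel directory
    env=env,
)
print("wheel size check", "passed" if result.returncode == 0 else "failed")
```
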
@@ -9,3 +9,4 @@ tasks:
    value: 0.664
limit: 1000
num_fewshot: 5
+trust_remote_code: True
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.419
  - name: "exact_match,flexible-extract"
    value: 0.416
limit: 1000
num_fewshot: 5
.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml (new file, 11 lines)
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1
model_name: "mgoin/Minitron-4B-Base-FP8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.233
  - name: "exact_match,flexible-extract"
    value: 0.236
limit: 1000
num_fewshot: 5
@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks:
- name: "gsm8k"
  metrics:
  - name: "exact_match,strict-match"
    value: 0.578
  - name: "exact_match,flexible-extract"
    value: 0.585
limit: 1000
num_fewshot: 5
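
The `value` entries in these configs are reference GSM8K scores; the lm-eval correctness test further down in this diff compares measured scores against them within a relative tolerance (its `RTOL` is raised to 0.05). A minimal sketch of such a tolerance check, with a hypothetical measured value:

```
# Minimal sketch (illustrative only): check a measured metric against the
# ground-truth value from the config above, within a relative tolerance.
import numpy

RTOL = 0.05                 # tolerance used by the lm-eval correctness test below
ground_truth = 0.578        # "exact_match,strict-match" value from the config above
measured_value = 0.569      # hypothetical score from a benchmark run
assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
print(f"{measured_value} is within {RTOL:.0%} of {ground_truth}")
```
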
@@ -1,7 +1,9 @@
Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Minitron-4B-Base-FP8.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
Meta-Llama-3-8B-QQQ.yaml
@@ -14,7 +14,7 @@ import lm_eval
import numpy
import yaml

-RTOL = 0.02
+RTOL = 0.05
TEST_DATA_FILE = os.environ.get(
    "LM_EVAL_TEST_DATA_FILE",
    ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml")
@@ -23,9 +23,12 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1)


def launch_lm_eval(eval_config):
+    trust_remote_code = eval_config.get('trust_remote_code', False)
+
    model_args = f"pretrained={eval_config['model_name']}," \
                 f"tensor_parallel_size={TP_SIZE}," \
-                 f"add_bos_token=true"
+                 f"add_bos_token=true," \
+                 f"trust_remote_code={trust_remote_code}"

    results = lm_eval.simple_evaluate(
        model="vllm",
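
For illustration only (not part of the diff): assuming `TP_SIZE` is 1 and a config that does not set `trust_remote_code`, the `model_args` string assembled above would come out as follows.

```
# Minimal sketch: reproduce the model_args string built in launch_lm_eval above,
# assuming TP_SIZE = 1 and a config without an explicit trust_remote_code key.
eval_config = {"model_name": "HandH1998/QQQ-Llama-3-8b-g128"}
TP_SIZE = 1
trust_remote_code = eval_config.get('trust_remote_code', False)
model_args = f"pretrained={eval_config['model_name']}," \
             f"tensor_parallel_size={TP_SIZE}," \
             f"add_bos_token=true," \
             f"trust_remote_code={trust_remote_code}"
print(model_args)
# pretrained=HandH1998/QQQ-Llama-3-8b-g128,tensor_parallel_size=1,add_bos_token=true,trust_remote_code=False
```
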
@@ -3,30 +3,52 @@

## Introduction

This directory contains the performance benchmarking CI for vllm.
The goal is to help developers know the impact of their PRs on the performance of vllm.
This directory contains two sets of benchmarks for vllm:
- Performance benchmark: benchmark vllm's performance under various workloads, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance.
- Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm.

This benchmark will be *triggered* upon:
- A PR being merged into vllm.
- Every commit for those PRs with `perf-benchmarks` label.

**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for more GPUs is coming later), with different models.
See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.


## Performance benchmark quick overview

**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.

**Benchmarking Duration**: about 1hr.

**For benchmarking developers**: please try your best to constrain the duration of benchmarking to less than 1.5 hr so that it won't take forever to run.
**For benchmarking developers**: please try your best to constrain the duration of benchmarking to about 1 hr so that it won't take forever to run.


## Configuring the workload
## Nightly benchmark quick overview

The benchmarking workload contains three parts:
- Latency tests in `latency-tests.json`.
- Throughput tests in `throughput-tests.json`.
- Serving tests in `serving-tests.json`.
**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B.

See [descriptions.md](tests/descriptions.md) for detailed descriptions.
**Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy.

### Latency test
**Benchmarking Duration**: about 3.5hrs.


## Trigger the benchmark

Performance benchmark will be triggered when:
- A PR being merged into vllm.
- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.

Nightly benchmark will be triggered when:
- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.


## Performance benchmark details

See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.

#### Latency test

Here is an example of one test inside `latency-tests.json`:

@@ -47,19 +69,19 @@ Here is an example of one test inside `latency-tests.json`:

In this example:
- The `test_name` attribute is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
- The `parameters` attribute controls the command line arguments to be used for `benchmark_latency.py`. Note: please use an underscore `_` instead of a dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underscore to a dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
- The `parameters` attribute controls the command line arguments to be used for `benchmark_latency.py`. Note: please use an underscore `_` instead of a dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underscore to a dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`

Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.

WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file.
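
As an aside, the underscore-to-dash conversion described in the bullet above is simple enough to sketch in Python (illustrative only; the suite itself does this in bash via its `json2args` helper, which appears later in this diff):

```
# Minimal sketch: turn a latency-tests.json "parameters" object into CLI flags,
# replacing '_' with '-' the way the suite's json2args helper does.
params = {
    "model": "meta-llama/Meta-Llama-3-8B",
    "tensor_parallel_size": 1,
    "load_format": "dummy",
    "num_iters_warmup": 5,
    "num_iters": 15,
}
args = " ".join(f"--{key.replace('_', '-')} {value}" for key, value in params.items())
print(f"python3 benchmark_latency.py {args}")
# python3 benchmark_latency.py --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15
```
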

### Throughput test
#### Throughput test
The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except that the parameters will be fed forward to `benchmark_throughput.py`.

The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot.

### Serving test
#### Serving test
We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example:

```

@@ -96,9 +118,36 @@ The number of this test is less stable compared to the delay and latency benchma

WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`.

## Visualizing the results
#### Visualizing the results
The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results.
You can find the result presented as a table inside the `buildkite/performance-benchmark` job page.
If you do not see the table, please wait till the benchmark finishes running.
The json version of the table (together with the json version of the benchmark) will also be attached to the markdown file.
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking job.


## Nightly test details

See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description of the test workload, models and docker containers used to benchmark other llm engines.

#### Workflow

- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines.
- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container.
- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark.
- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.

#### Nightly tests

In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to the performance benchmark.

#### Docker containers

The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.

WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`.

WARNING: populating `trt-llm` to the latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).
@@ -21,7 +21,7 @@ steps:
          containers:
          - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
            command:
-            - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
+            - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
            resources:
              limits:
                nvidia.com/gpu: 8
@@ -42,20 +42,20 @@ steps:
      - name: devshm
        emptyDir:
          medium: Memory
-  - label: "H100"
-    agents:
-      queue: H100
-    plugins:
-      - docker#v5.11.0:
-          image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
-          command:
-            - bash
-            - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
-          mount-buildkite-agent: true
-          propagate-environment: true
-          ipc: host
-          gpus: all
-          environment:
-            - VLLM_USAGE_SOURCE
-            - HF_TOKEN
+  # - label: "H100"
+  #   agents:
+  #     queue: H100
+  #   plugins:
+  #     - docker#v5.11.0:
+  #         image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
+  #         command:
+  #           - bash
+  #           - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
+  #         mount-buildkite-agent: true
+  #         propagate-environment: true
+  #         ipc: host
+  #         gpus: all
+  #         environment:
+  #           - VLLM_USAGE_SOURCE
+  #           - HF_TOKEN
@@ -1,47 +1,42 @@

## Latency tests

This test suite aims to test vllm's end-to-end latency under a controlled setup.

- Input length: 32 tokens.
- Output length: 128 tokens.
- Batch size: fixed (8).
- Models: llama-3 8B, llama-3 70B, mixtral 8x7B.
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- Evaluation metrics: end-to-end latency (mean, median, p99).

### Latency benchmarking results

{latency_tests_markdown_table}

## Throughput tests

This test suite aims to test vllm's throughput.
## Throughput tests

- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
- Output length: the corresponding output length of these 200 prompts.
- Batch size: dynamically determined by vllm to achieve maximum throughput.
- Models: llama-3 8B, llama-3 70B, mixtral 8x7B.
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- Evaluation metrics: throughput.

### Throughput benchmarking results

{throughput_tests_markdown_table}

## Serving tests

This test suite aims to test vllm's real serving metrics.
## Serving tests

- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
- Output length: the corresponding output length of these 200 prompts.
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
- Models: llama-3 8B, llama-3 70B, mixtral 8x7B.
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- We also added a speculative decoding test for llama-3 70B, under QPS 2
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).

### Serving benchmarking results

{serving_tests_markdown_table}


## json version of the benchmarking tables

This section contains the data of the markdown tables above in JSON format.
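
The serving tests above draw request arrival times from a Poisson process at a fixed average QPS. A minimal sketch of generating such an arrival pattern (an illustration with an assumed QPS value, not the benchmark's own code):

```
# Minimal sketch: generate Poisson-process arrival times for a given average QPS,
# as used by the serving tests above (fixed seed for reproducibility).
import random

random.seed(0)
qps = 4  # assumed average queries per second
arrival_time = 0.0
arrival_times = []
for _ in range(200):  # 200 prompts, matching the ShareGPT sample size above
    arrival_time += random.expovariate(qps)  # exponential inter-arrival gaps
    arrival_times.append(arrival_time)
print(f"first arrivals (s): {[round(t, 2) for t in arrival_times[:5]]}")
```
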
@@ -174,8 +174,8 @@ if __name__ == "__main__":
    # document the result
    with open(results_folder / "benchmark_results.md", "w") as f:

-        results = read_markdown(
-            "../.buildkite/nightly-benchmarks/tests/descriptions.md")
+        results = read_markdown("../.buildkite/nightly-benchmarks/" +
+                                "performance-benchmarks-descriptions.md")
        results = results.format(
            latency_tests_markdown_table=latency_md_table,
            throughput_tests_markdown_table=throughput_md_table,
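
For illustration, the `{latency_tests_markdown_table}`-style placeholders in the descriptions file are filled with plain `str.format`, as the snippet above does; a minimal sketch with a hypothetical table string:

```
# Minimal sketch: fill a {latency_tests_markdown_table}-style placeholder the way
# convert-results-json-to-markdown.py does, using str.format on the template text.
template = "## Latency tests\n\n{latency_tests_markdown_table}\n"
latency_md_table = "| Test name | Mean latency (ms) |\n|---|---|\n| latency_llama8B_tp1 | ... |"  # hypothetical table
print(template.format(latency_tests_markdown_table=latency_md_table))
```
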
@ -34,6 +34,15 @@ check_hf_token() {
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_sharegpt_downloaded() {
|
||||
local FILE=ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
if [ ! -f "$FILE" ]; then
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE
|
||||
else
|
||||
echo "$FILE already exists."
|
||||
fi
|
||||
}
|
||||
|
||||
json2args() {
|
||||
# transforms the JSON string to command line args, and '_' is replaced to '-'
|
||||
# example:
|
||||
@ -59,40 +68,38 @@ wait_for_server() {
|
||||
done' && return 0 || return 1
|
||||
}
|
||||
|
||||
kill_gpu_processes() {
|
||||
# kill all processes on GPU.
|
||||
pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)
|
||||
if [ -z "$pids" ]; then
|
||||
echo "No GPU processes found."
|
||||
kill_processes_launched_by_current_bash() {
|
||||
# Kill all python processes launched from current bash script
|
||||
current_shell_pid=$$
|
||||
processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}')
|
||||
if [ -n "$processes" ]; then
|
||||
echo "Killing the following processes matching '$1':"
|
||||
echo "$processes"
|
||||
echo "$processes" | xargs kill -9
|
||||
else
|
||||
for pid in $pids; do
|
||||
kill -9 "$pid"
|
||||
echo "Killed process with PID: $pid"
|
||||
done
|
||||
|
||||
echo "All GPU processes have been killed."
|
||||
echo "No processes found matching '$1'."
|
||||
fi
|
||||
}
|
||||
|
||||
# Sometimes kill with pid doesn't work properly, we can also kill all process running python or python3
|
||||
# since we are in container anyway
|
||||
pkill -9 -f python
|
||||
pkill -9 -f python3
|
||||
kill_gpu_processes() {
|
||||
|
||||
# waiting for GPU processes to be fully killed
|
||||
# loop while nvidia-smi returns any processes
|
||||
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
|
||||
ps -aux
|
||||
lsof -t -i:8000 | xargs -r kill -9
|
||||
pkill -f pt_main_thread
|
||||
# this line doesn't work now
|
||||
# ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9
|
||||
pkill -f python3
|
||||
pkill -f /usr/bin/python3
|
||||
|
||||
|
||||
# wait until GPU memory usage smaller than 1GB
|
||||
while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do
|
||||
sleep 1
|
||||
echo "Waiting for GPU processes to be killed"
|
||||
done
|
||||
|
||||
# remove vllm config file
|
||||
rm -rf ~/.config/vllm
|
||||
|
||||
# Print the GPU memory usage
|
||||
# so that we know if all GPU processes are killed.
|
||||
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
|
||||
# The memory usage should be 0 MB.
|
||||
echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
|
||||
}
|
||||
|
||||
upload_to_buildkite() {
|
||||
@ -110,7 +117,7 @@ upload_to_buildkite() {
|
||||
fi
|
||||
|
||||
# Use the determined command to annotate and upload artifacts
|
||||
$BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md
|
||||
$BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md
|
||||
$BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*"
|
||||
}
|
||||
|
||||
@ -162,7 +169,7 @@ run_latency_tests() {
|
||||
latency_command: $latency,
|
||||
gpu_type: $gpu
|
||||
}')
|
||||
echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands"
|
||||
echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands"
|
||||
|
||||
# run the benchmark
|
||||
eval "$latency_command"
|
||||
@ -172,7 +179,6 @@ run_latency_tests() {
|
||||
done
|
||||
}
|
||||
|
||||
|
||||
run_throughput_tests() {
|
||||
# run throughput tests using `benchmark_throughput.py`
|
||||
# $1: a json file specifying throughput test cases
|
||||
@ -220,7 +226,7 @@ run_throughput_tests() {
|
||||
throughput_command: $command,
|
||||
gpu_type: $gpu
|
||||
}')
|
||||
echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands"
|
||||
echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands"
|
||||
|
||||
# run the benchmark
|
||||
eval "$throughput_command"
|
||||
@ -252,7 +258,6 @@ run_serving_tests() {
|
||||
continue
|
||||
fi
|
||||
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r '.server_parameters')
|
||||
client_params=$(echo "$params" | jq -r '.client_parameters')
|
||||
@ -330,7 +335,7 @@ run_serving_tests() {
|
||||
client_command: $client,
|
||||
gpu_type: $gpu
|
||||
}')
|
||||
echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands"
|
||||
echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands"
|
||||
|
||||
done
|
||||
|
||||
@ -347,6 +352,7 @@ main() {
|
||||
# dependencies
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
(which jq) || (apt-get update && apt-get -y install jq)
|
||||
(which lsof) || (apt-get update && apt-get install -y lsof)
|
||||
|
||||
# get the current IP address, required by benchmark_serving.py
|
||||
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
||||
@ -355,7 +361,7 @@ main() {
|
||||
|
||||
# prepare for benchmarking
|
||||
cd benchmarks || exit 1
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
ensure_sharegpt_downloaded
|
||||
declare -g RESULTS_FOLDER=results/
|
||||
mkdir -p $RESULTS_FOLDER
|
||||
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/
|
||||
@ -365,7 +371,6 @@ main() {
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
||||
|
||||
|
||||
# postprocess benchmarking results
|
||||
pip install tabulate pandas
|
||||
python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py
|
@ -2,7 +2,7 @@
|
||||
{
|
||||
"test_name": "latency_llama8B_tp1",
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
@ -12,7 +12,7 @@
|
||||
{
|
||||
"test_name": "latency_llama70B_tp4",
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"num-iters-warmup": 5,
|
||||
|
@ -3,7 +3,7 @@
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"swap_space": 16,
|
||||
"disable_log_stats": "",
|
||||
@ -11,7 +11,7 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
@ -22,7 +22,7 @@
|
||||
"test_name": "serving_llama70B_tp4_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"swap_space": 16,
|
||||
"disable_log_stats": "",
|
||||
@ -30,7 +30,7 @@
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
@ -55,5 +55,26 @@
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama70B_tp4_sharegpt_specdecode",
|
||||
"qps_list": [2],
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"disable_log_requests": "",
|
||||
"tensor_parallel_size": 4,
|
||||
"swap_space": 16,
|
||||
"speculative_model": "turboderp/Qwama-0.5B-Instruct",
|
||||
"num_speculative_tokens": 4,
|
||||
"speculative_draft_tensor_parallel_size": 1,
|
||||
"use_v2_block_manager": ""
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -2,7 +2,7 @@
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-8B",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
@ -13,7 +13,7 @@
|
||||
{
|
||||
"test_name": "throughput_llama70B_tp4",
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
|
@ -1,9 +1,27 @@
|
||||
steps:
|
||||
- label: "Build wheel - CUDA {{matrix.cuda_version}}"
|
||||
- label: "Build wheel - CUDA 12.1"
|
||||
agents:
|
||||
queue: cpu_queue
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
# rename the files to change linux -> manylinux1
|
||||
- "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
|
||||
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/"
|
||||
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
- block: "Build CUDA 11.8 wheel"
|
||||
key: block-build-cu118-wheel
|
||||
|
||||
- label: "Build wheel - CUDA 11.8"
|
||||
depends_on: block-build-cu118-wheel
|
||||
agents:
|
||||
queue: cpu_queue
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
# rename the files to change linux -> manylinux1
|
||||
@ -12,8 +30,3 @@ steps:
|
||||
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
matrix:
|
||||
setup:
|
||||
cuda_version:
|
||||
- "11.8.0"
|
||||
- "12.1.0"
|
||||
|
48
.buildkite/run-amd-test.sh
Normal file → Executable file
48
.buildkite/run-amd-test.sh
Normal file → Executable file
@ -1,5 +1,5 @@
|
||||
# This script runs test inside the corresponding ROCm docker container.
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# Print ROCm version
|
||||
echo "--- Confirming Clean Initial State"
|
||||
@ -55,7 +55,7 @@ while true; do
|
||||
done
|
||||
|
||||
echo "--- Pulling container"
|
||||
image_name="rocmshared/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
|
||||
docker pull ${image_name}
|
||||
|
||||
@ -70,15 +70,51 @@ HF_CACHE="$(realpath ~)/huggingface"
|
||||
mkdir -p ${HF_CACHE}
|
||||
HF_MOUNT="/root/.cache/huggingface"
|
||||
|
||||
docker run \
|
||||
commands=$@
|
||||
PARALLEL_JOB_COUNT=8
|
||||
# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs.
|
||||
if [[ $commands == *"--shard-id="* ]]; then
|
||||
for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do
|
||||
#replace shard arguments
|
||||
commands=${@//"--shard-id= "/"--shard-id=${GPU} "}
|
||||
commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "}
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES=${GPU} \
|
||||
-e HF_TOKEN \
|
||||
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||
-e HF_HOME=${HF_MOUNT} \
|
||||
--name ${container_name} \
|
||||
--name ${container_name}_${GPU} \
|
||||
${image_name} \
|
||||
/bin/bash -c "${@}"
|
||||
|
||||
/bin/bash -c "${commands}" \
|
||||
|& while read -r line; do echo ">>Shard $GPU: $line"; done &
|
||||
PIDS+=($!)
|
||||
done
|
||||
#wait for all processes to finish and collect exit codes
|
||||
for pid in ${PIDS[@]}; do
|
||||
wait ${pid}
|
||||
STATUS+=($?)
|
||||
done
|
||||
for st in ${STATUS[@]}; do
|
||||
if [[ ${st} -ne 0 ]]; then
|
||||
echo "One of the processes failed with $st"
|
||||
exit ${st}
|
||||
fi
|
||||
done
|
||||
else
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HIP_VISIBLE_DEVICES=0 \
|
||||
-e HF_TOKEN \
|
||||
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||
-e HF_HOME=${HF_MOUNT} \
|
||||
--name ${container_name} \
|
||||
${image_name} \
|
||||
/bin/bash -c "${commands}"
|
||||
fi
|
||||
|
@ -3,26 +3,43 @@
|
||||
set -ex
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t cpu-test -f Dockerfile.cpu .
|
||||
docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
|
||||
numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
|
||||
numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
|
||||
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
|
||||
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
|
||||
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2
|
||||
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
|
||||
|
||||
# offline inference
|
||||
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
|
||||
docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
|
||||
|
||||
# Run basic model test
|
||||
docker exec cpu-test bash -c "cd tests;
|
||||
pip install pytest Pillow protobuf
|
||||
cd ../
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
|
||||
docker exec cpu-test bash -c "
|
||||
pip install pytest matplotlib einops transformers_stream_generator
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \
|
||||
--ignore=tests/models/test_oot_registration.py \
|
||||
--ignore=tests/models/test_registry.py \
|
||||
--ignore=tests/models/test_fp8.py \
|
||||
--ignore=tests/models/test_jamba.py \
|
||||
--ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||
|
||||
# online inference
|
||||
docker exec cpu-test bash -c "
|
||||
export VLLM_CPU_KVCACHE_SPACE=10
|
||||
export VLLM_CPU_OMP_THREADS_BIND=48-92
|
||||
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--dataset-name random \
|
||||
--model facebook/opt-125m \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer facebook/opt-125m"
|
||||
|
@ -12,5 +12,4 @@ remove_docker_container
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
# Run a simple end-to-end example.
|
||||
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
|
||||
python3 /workspace/vllm/examples/offline_inference_tpu.py
|
||||
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
|
||||
|
@@ -5,11 +5,49 @@
# https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2
# to generate the final pipeline yaml file.

# Documentation
# label(str): the name of the test. emoji allowed.
# fast_check(bool): whether to run this on each commit on the fastcheck pipeline.
# fast_check_only(bool): run this test on the fastcheck pipeline only
# command(str): the single command to run for tests. incompatible with commands.
# commands(list): the list of commands to run for the test. incompatible with command.
# mirror_hardwares(list): the list of hardware platforms to run the test on as well. currently only supports [amd]
# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100
# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently supports 2,4.
# num_nodes(int): whether to simulate a multi-node setup by launching multiple containers on one host,
#   in this case, commands must be specified. the first command runs on the first host, the second
#   command runs on the second host.
# working_dir(str): specify the place where the command should execute, default to /vllm-workspace/tests
# source_file_dependencies(list): the list of prefixes to opt the test in for; if empty, the test will always run.

# When adding a test
# - If the test belongs to an existing group, add it there
# - If the test is short, add it to any existing step
# - If the test takes more than 10min, then it is okay to create a new step.
# Note that all steps execute in parallel.

steps:
|
||||
- label: Async Engine, Inputs, Utils, Worker Test
|
||||
##### fast check tests #####
|
||||
|
||||
- label: Documentation Build # 2min
|
||||
working_dir: "/vllm-workspace/test_docs/docs"
|
||||
fast_check: true
|
||||
fast_check_only: true
|
||||
no_gpu: True
|
||||
commands:
|
||||
- pip install -r requirements-docs.txt
|
||||
- SPHINXOPTS=\"-W\" make html
|
||||
# Check API reference (if it fails, you may have missing mock imports)
|
||||
- grep \"sig sig-object py\" build/html/dev/sampling_params.html
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 15min
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/async_engine
|
||||
- tests/test_inputs
|
||||
- tests/multimodal
|
||||
- tests/test_utils
|
||||
- tests/worker
|
||||
commands:
|
||||
- pytest -v -s async_engine # Async Engine
|
||||
- pytest -v -s test_inputs.py
|
||||
@ -17,273 +55,361 @@ steps:
|
||||
- pytest -v -s test_utils.py # Utils
|
||||
- pytest -v -s worker # Worker
|
||||
|
||||
- label: Tensorizer, Metrics, Tracing Test
|
||||
fast_check: true
|
||||
fast_check_only: true
|
||||
commands:
|
||||
- apt-get install -y curl libsodium23 && pytest -v -s tensorizer_loader # Tensorizer
|
||||
- pytest -v -s metrics # Metrics
|
||||
- "pip install \
|
||||
opentelemetry-sdk \
|
||||
opentelemetry-api \
|
||||
opentelemetry-exporter-otlp \
|
||||
opentelemetry-semantic-conventions-ai" # Tracing
|
||||
- pytest -v -s tracing
|
||||
|
||||
- label: Regression Test
|
||||
mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
command: pytest -v -s test_regression.py
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: AsyncEngine Test
|
||||
- label: Basic Correctness Test # 30min
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s async_engine
|
||||
|
||||
- label: Basic Correctness Test
|
||||
mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness
|
||||
commands:
|
||||
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Core Test
|
||||
|
||||
- label: Core Test # 10min
|
||||
mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/core
|
||||
- vllm/distributed
|
||||
- tests/core
|
||||
commands:
|
||||
- pytest -v -s core
|
||||
- pytest -v -s distributed/test_parallel_state.py
|
||||
|
||||
- label: Distributed Comm Ops Test
|
||||
- label: Entrypoints Test # 20min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
#mirror_hardwares: [amd]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
commands:
|
||||
- pytest -v -s distributed/test_comm_ops.py
|
||||
- pytest -v -s distributed/test_shm_broadcast.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
|
||||
- label: 2 Node Tests (4 GPUs in total)
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
num_nodes: 2
|
||||
commands:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
||||
|
||||
- label: Distributed Tests (2 GPUs)
|
||||
mirror_hardwares: [amd]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||
|
||||
- label: Distributed Tests (4 GPUs)
|
||||
#mirror_hardwares: [amd]
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/
|
||||
- vllm/core/
|
||||
- tests/distributed
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
|
||||
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_COMPILED_DAG=1 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
|
||||
- label: Pipeline Parallelism Test
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
num_gpus: 2
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/metrics
|
||||
- tests/tracing
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- pytest -v -s metrics
|
||||
- "pip install \
|
||||
'opentelemetry-sdk>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-api>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \
|
||||
'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'"
|
||||
- pytest -v -s tracing
|
||||
|
||||
- label: Engine Test
|
||||
##### fast check tests #####
|
||||
##### 1 GPU test #####
|
||||
|
||||
- label: Regression Test # 5min
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
command: pytest -v -s test_regression.py
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
- tests/tokenization
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
|
||||
# OOM in the CI unless we run this separately
|
||||
- pytest -v -s tokenization
|
||||
|
||||
- label: Entrypoints Test
|
||||
fast_check: true
|
||||
mirror_hardwares: [amd]
|
||||
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm
|
||||
- pytest -v -s entrypoints/openai
|
||||
|
||||
- label: Examples Test
- label: Examples Test # 12min
working_dir: "/vllm-workspace/examples"
mirror_hardwares: [amd]
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/entrypoints
- examples/
commands:
# install aws cli for llava_example.py
# install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- pip install awscli tensorizer # for llava example and tensorizer test
- python3 offline_inference.py
- python3 cpu_offload.py
- python3 offline_inference_chat.py
- python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py
- python3 llava_example.py
- python3 offline_inference_vision_language.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py

- label: Inputs Test
#mirror_hardwares: [amd]
- label: Models Test # 1hr10min
source_file_dependencies:
- vllm/
- tests/models
commands:
- pytest -v -s test_inputs.py
- pytest -v -s multimodal
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s models/test_oot_registration.py # it needs a clean process
- pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py

- label: Kernels Test %N
#mirror_hardwares: [amd]
- label: torch compile integration test
source_file_dependencies:
- vllm/
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4
- pytest -v -s ./compile/test_full_graph.py
- pytest -v -s ./compile/test_wrapper.py

- label: Models Test

- label: Vision Language Models Test # 42min
#mirror_hardwares: [amd]
commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- pytest -v -s models -m \"not vlm\"

- label: Vision Language Models Test
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
commands:
- pytest -v -s models -m vlm

- label: Prefix Caching Test
mirror_hardwares: [amd]
- label: Prefix Caching Test # 7min
#mirror_hardwares: [amd]
source_file_dependencies:
- vllm/
- tests/prefix_caching
commands:
- pytest -v -s prefix_caching

- label: Samplers Test
#mirror_hardwares: [amd]
command: pytest -v -s samplers
- label: Samplers Test # 18min
source_file_dependencies:
- vllm/model_executor/layers
- vllm/sampling_metadata.py
- tests/samplers
commands:
- pytest -v -s samplers
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

- label: LogitsProcessor Test
- label: LogitsProcessor Test # 5min
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
- tests/test_logits_processor
command: pytest -v -s test_logits_processor.py

- label: Utils Test
commands:
- pytest -v -s test_utils.py
- pytest -v -s test_embedded_commit.py

- label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker

- label: Speculative decoding tests
#mirror_hardwares: [amd]
- label: Speculative decoding tests # 22min
source_file_dependencies:
- vllm/spec_decode
- tests/spec_decode
commands:
# See https://github.com/vllm-project/vllm/issues/5152
- export VLLM_ATTENTION_BACKEND=XFORMERS
- pytest -v -s spec_decode

- label: LoRA Test %N
#mirror_hardwares: [amd]
- label: LoRA Test %N # 30min each
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4

- label: LoRA Long Context (Distributed)
- label: Kernels Test %N # 30min each
source_file_dependencies:
- csrc/
- vllm/attention
- tests/kernels
commands:
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4

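The `%N` labels above pair with `parallelism` and the Buildkite shard variables to fan one suite out across several jobs. The sketch below only restates that pattern; the assumption that `%N` is rendered as the shard number in the UI, and the exact behaviour of the repository's pipeline generator, are not confirmed here.

# Sketch of a sharded step (illustrative only).
# Buildkite sets BUILDKITE_PARALLEL_JOB (0-based index) and
# BUILDKITE_PARALLEL_JOB_COUNT for steps with `parallelism`; the doubled $$
# defers variable expansion to job run time instead of pipeline upload time.
- label: Kernels Test %N
  parallelism: 4
  source_file_dependencies:
  - csrc/
  - tests/kernels
  commands:
  - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT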
- label: Tensorizer Test # 11min
mirror_hardwares: [amd]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
- tests/tensorizer_loader
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s tensorizer_loader

- label: Benchmarks # 9min
working_dir: "/vllm-workspace/.buildkite"
mirror_hardwares: [amd]
source_file_dependencies:
- benchmarks/
commands:
- pip install aiohttp
- bash run-benchmarks.sh

- label: Quantization Test # 15min
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
- tests/quantization
command: pytest -v -s quantization

- label: LM Eval Small Models # 53min
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pip install lm-eval
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: OpenAI-Compatible Tool Use # 20 min
fast_check: false
mirror_hardwares: [ amd ]
source_file_dependencies:
- vllm/
- tests/tool_use
commands:
- pytest -v -s tool_use

##### 1 GPU test #####
##### multi gpus test #####

- label: Distributed Comm Ops Test # 7min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
- vllm/distributed
- tests/distributed
commands:
- pytest -v -s distributed/test_comm_ops.py
- pytest -v -s distributed/test_shm_broadcast.py

- label: 2 Node Tests (4 GPUs in total) # 16min
working_dir: "/vllm-workspace/tests"
num_gpus: 2
num_nodes: 2
source_file_dependencies:
- vllm/distributed/
- vllm/engine/
- vllm/executor/
- vllm/model_executor/models/
- tests/distributed/
commands:
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py

- label: Distributed Tests (2 GPUs) # 28min
#mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
- vllm/distributed/
- vllm/engine/
- vllm/executor/
- vllm/model_executor/models/
- tests/distributed/
commands:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py
- pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s distributed/test_distributed_oot.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py

- label: Multi-step Tests (4 GPUs) # 21min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
- vllm/model_executor/layers/sampler.py
- vllm/sequence.py
- vllm/worker/worker_base.py
- vllm/worker/worker.py
- vllm/worker/multi_step_worker.py
- vllm/worker/model_runner_base.py
- vllm/worker/model_runner.py
- vllm/worker/multi_step_model_runner.py
- vllm/engine
- tests/multi_step
commands:
- pytest -v -s multi_step/test_correctness_async_llm.py
- pytest -v -s multi_step/test_correctness_llm.py

- label: Pipeline Parallelism Test # 23min
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
- vllm/distributed/
- vllm/engine/
- vllm/executor/
- vllm/model_executor/models/
- tests/distributed/
commands:
- pytest -v -s distributed/test_pp_cudagraph.py
- pytest -v -s distributed/test_pipeline_parallel.py

- label: LoRA Long Context (Distributed) # 11min
# This test runs llama 13B, so it is required to run on 4 GPUs.
num_gpus: 4
source_file_dependencies:
- vllm/lora
- tests/lora/test_long_context
commands:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s -x lora/test_long_context.py

- label: Tensorizer Test
#mirror_hardwares: [amd]
- label: Weight Loading Multiple GPU Test
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
- vllm/
- tests/weight_loading
commands:
- apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s tensorizer_loader
- bash weight_loading/run_model_weight_loading_test.sh

- label: Metrics Test
mirror_hardwares: [amd]
command: pytest -v -s metrics

- label: Quantization Test
#mirror_hardwares: [amd]
command: pytest -v -s quantization
##### multi gpus test #####
##### A100 test #####

- label: Tracing Test
commands:
- "pip install \
  opentelemetry-sdk \
  opentelemetry-api \
  opentelemetry-exporter-otlp \
  opentelemetry-semantic-conventions-ai"
- pytest -v -s tracing

- label: Benchmarks
working_dir: "/vllm-workspace/.buildkite"
mirror_hardwares: [amd]
commands:
- pip install aiohttp
- bash run-benchmarks.sh

- label: LM Eval Small Models
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
commands:
- pip install lm-eval
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1

- label: LM Eval Large Models
gpu: a100
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
commands:
- pip install lm-eval
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-large.txt -t 4

- label: Documentation Build
working_dir: "/vllm-workspace/test_docs/docs"
fast_check: true
no_gpu: True
commands:
- pip install -r requirements-docs.txt
- SPHINXOPTS=\"-W\" make html

- label: Distributed Tests (A100)
- label: Distributed Tests (A100) # optional
gpu: a100
num_gpus: 4
source_file_dependencies:
- vllm/
commands:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py

- label: LM Eval Large Models # optional
gpu: a100
num_gpus: 4
working_dir: "/vllm-workspace/.buildkite/lm-eval-harness"
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
commands:
- pip install lm-eval
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-large.txt -t 4

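For reference, the step entries above combine a small set of recurring keys. The composite sketch below only summarizes the fields that appear in this listing, with placeholder values; the authoritative schema and behaviour are defined by the repository's pipeline generator and by Buildkite, neither of which is reproduced here.

# Composite sketch of the step fields used in this pipeline (illustrative only).
- label: Example Test # 10min   # human-readable name, optionally with a runtime estimate
  working_dir: "/vllm-workspace/tests"   # working directory inside the CI container
  num_gpus: 2              # GPUs requested for the step
  num_nodes: 2             # only set for multi-node steps
  gpu: a100                # optional hardware selector
  fast_check: true         # include the step in the fast-check subset
  soft_fail: true          # the step may fail without failing the build
  mirror_hardwares: [amd]  # also run the step on the listed alternate hardware
  source_file_dependencies:  # paths whose changes should trigger this step
  - vllm/distributed/
  - tests/distributed
  commands:                # or a single `command:` string
  - pytest -v -s distributed/test_comm_ops.py
  parallelism: 4           # shard the step across N parallel jobs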
@@ -1 +1,4 @@
vllm/*.so
/.venv
/build
dist

.github/ISSUE_TEMPLATE/100-documentation.yml (vendored, 7 changes)
@ -20,3 +20,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
.github/ISSUE_TEMPLATE/200-installation.yml (vendored, 7 changes)
@ -38,3 +38,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
.github/ISSUE_TEMPLATE/300-usage.yml (vendored, 7 changes)
@ -36,3 +36,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
.github/ISSUE_TEMPLATE/400-bug report.yml (vendored, 14 changes)
@ -20,9 +20,14 @@ body:
|
||||
```
|
||||
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
|
||||
value: |
|
||||
<details>
|
||||
<summary>The output of `python collect_env.py`</summary>
|
||||
|
||||
```text
|
||||
The output of `python collect_env.py`
|
||||
Your output of `python collect_env.py` here
|
||||
```
|
||||
|
||||
</details>
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
@ -84,3 +89,10 @@ body:
|
||||
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
|
||||
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
@ -29,3 +29,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
.github/ISSUE_TEMPLATE/600-new model.yml (vendored, 7 changes)
@ -31,3 +31,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
@ -50,3 +50,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
.github/ISSUE_TEMPLATE/750-RFC.yml (vendored, 7 changes)
@ -47,3 +47,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
@ -19,3 +19,10 @@ body:
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
label: Before submitting a new issue...
|
||||
options:
|
||||
- label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.
|
||||
required: true
|
||||
|
.github/workflows/add_label_ready_comment.yml (vendored, 23 changes)
@ -1,23 +0,0 @@
|
||||
name: Add Ready Label on Ready Comment
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
add-ready-label:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready')
|
||||
steps:
|
||||
- name: Add label
|
||||
uses: actions/github-script@v5
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
labels: ['ready']
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
.github/workflows/clang-format.yml (vendored, 11 changes)
@ -30,12 +30,11 @@ jobs:
|
||||
run: |
|
||||
EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
|
||||
'csrc/punica/bgmv/bgmv_config.h'
|
||||
'csrc/punica/bgmv/bgmv_impl.cuh'
|
||||
'csrc/punica/bgmv/vec_dtypes.cuh'
|
||||
'csrc/punica/punica_ops.cu'
|
||||
'csrc/punica/type_convert.h'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||
|
.github/workflows/mypy.yaml (vendored, 33 changes)
@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
@ -25,29 +25,22 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install mypy==1.9.0
|
||||
pip install mypy==1.11.1
|
||||
pip install types-setuptools
|
||||
pip install types-PyYAML
|
||||
pip install types-requests
|
||||
pip install types-setuptools
|
||||
- name: Mypy
|
||||
run: |
|
||||
mypy tests --config-file pyproject.toml
|
||||
mypy vllm/*.py --config-file pyproject.toml
|
||||
mypy vllm/attention --config-file pyproject.toml
|
||||
mypy vllm/core --config-file pyproject.toml
|
||||
mypy vllm/distributed --config-file pyproject.toml
|
||||
mypy vllm/engine --config-file pyproject.toml
|
||||
mypy vllm/entrypoints --config-file pyproject.toml
|
||||
mypy vllm/executor --config-file pyproject.toml
|
||||
mypy vllm/inputs --config-file pyproject.toml
|
||||
mypy vllm/logging --config-file pyproject.toml
|
||||
mypy vllm/lora --config-file pyproject.toml
|
||||
mypy vllm/model_executor --config-file pyproject.toml
|
||||
mypy vllm/multimodal --config-file pyproject.toml
|
||||
mypy vllm/platforms --config-file pyproject.toml
|
||||
mypy vllm/spec_decode --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils --config-file pyproject.toml
|
||||
mypy vllm/usage --config-file pyproject.toml
|
||||
mypy vllm/worker --config-file pyproject.toml
|
||||
mypy
|
||||
mypy tests --follow-imports skip
|
||||
mypy vllm/attention --follow-imports skip
|
||||
mypy vllm/distributed --follow-imports skip
|
||||
mypy vllm/engine --follow-imports skip
|
||||
mypy vllm/executor --follow-imports skip
|
||||
mypy vllm/lora --follow-imports skip
|
||||
mypy vllm/model_executor --follow-imports skip
|
||||
mypy vllm/prompt_adapter --follow-imports skip
|
||||
mypy vllm/spec_decode --follow-imports skip
|
||||
mypy vllm/worker --follow-imports skip
|
||||
|
||||
|
.github/workflows/publish.yml (vendored, 4 changes)
@ -48,8 +48,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: ['2.3.1'] # Must be the most recent version that meets requirements-cuda.txt.
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
|
||||
cuda-version: ['11.8', '12.1']
|
||||
|
||||
steps:
|
||||
|
.github/workflows/reminder_comment.yml (vendored, 2 changes)
@ -15,7 +15,7 @@ jobs:
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
.github/workflows/ruff.yml (vendored, 2 changes)
@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
.github/workflows/scripts/build.sh (vendored, 2 changes)
@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt
|
||||
|
||||
# Limit the number of parallel jobs to avoid OOM
|
||||
export MAX_JOBS=1
|
||||
# Make sure punica is built for the release (for LoRA)
|
||||
export VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
# Make sure release wheels are built for the following architectures
|
||||
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
||||
# Build
|
||||
|
.github/workflows/yapf.yml (vendored, 2 changes)
@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
.gitignore (vendored, 5 changes)
@ -87,6 +87,9 @@ target/
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# generated files
|
||||
**/generated/**
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
@ -189,4 +192,4 @@ _build/
|
||||
hip_compat.h
|
||||
|
||||
# Benchmark dataset
|
||||
*.json
|
||||
benchmarks/*.json
|
||||
|
@ -10,6 +10,7 @@ build:
|
||||
|
||||
sphinx:
|
||||
configuration: docs/source/conf.py
|
||||
fail_on_warning: true
|
||||
|
||||
# If using Sphinx, optionally build your docs in additional formats such as PDF
|
||||
formats:
|
||||
|
CMakeLists.txt (186 changes)
@ -1,4 +1,4 @@
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
cmake_minimum_required(VERSION 3.26)
|
||||
|
||||
project(vllm_extensions LANGUAGES CXX)
|
||||
|
||||
@ -10,11 +10,14 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
||||
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
|
||||
# Suppress potential warnings about unused manually-specified variables
|
||||
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||
|
||||
#
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
#
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
|
||||
|
||||
# Supported NVIDIA architectures.
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
|
||||
@ -32,7 +35,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
|
||||
|
||||
#
|
||||
@ -66,6 +69,39 @@ endif()
|
||||
#
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
#
|
||||
# Add the `default` target which detects which extensions should be
|
||||
# built based on platform/architecture. This is the same logic that
|
||||
# setup.py uses to select which extensions should be built and should
|
||||
# be kept in sync.
|
||||
#
|
||||
# The `default` target makes direct use of cmake easier since knowledge
|
||||
# of which extensions are supported has been factored in, e.g.
|
||||
#
|
||||
# mkdir build && cd build
|
||||
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
|
||||
# cmake --build . --target default
|
||||
#
|
||||
add_custom_target(default)
|
||||
message(STATUS "Enabling core extension.")
|
||||
|
||||
# Define _core_C extension
|
||||
# built for (almost) every target platform, (excludes TPU and Neuron)
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/core/torch_bindings.cpp")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_core_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
add_dependencies(default _core_C)
|
||||
|
||||
#
|
||||
# Forward the non-CUDA device extensions to external CMake scripts.
|
||||
#
|
||||
@ -74,7 +110,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
|
||||
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
|
||||
return()
|
||||
endif()
|
||||
return()
|
||||
endif()
|
||||
@ -132,7 +168,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
# Define other extension targets
|
||||
#
|
||||
|
||||
#
|
||||
@ -156,23 +192,28 @@ set(VLLM_EXT_SRC
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(FetchContent)
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# CUTLASS 3.5.0
|
||||
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
|
||||
# CUTLASS 3.5.1
|
||||
GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
|
||||
GIT_PROGRESS TRUE
|
||||
)
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/quantization/aqlm/gemm_kernels.cu"
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
||||
"csrc/quantization/gguf/gguf_kernel.cu"
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
@ -191,6 +232,51 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
|
||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
|
||||
#
|
||||
# For the Machete kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
# Generate sources:
|
||||
execute_process(
|
||||
COMMAND ${CMAKE_COMMAND} -E env
|
||||
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
|
||||
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
|
||||
RESULT_VARIABLE machete_generation_result
|
||||
OUTPUT_VARIABLE machete_generation_output
|
||||
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
|
||||
)
|
||||
|
||||
if (NOT machete_generation_result EQUAL 0)
|
||||
message(FATAL_ERROR "Machete generation failed."
|
||||
" Result: \"${machete_generation_result}\""
|
||||
"\nCheck the log for details: "
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
|
||||
else()
|
||||
message(STATUS "Machete generation completed successfully.")
|
||||
endif()
|
||||
|
||||
# Add machete generated sources
|
||||
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
|
||||
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
|
||||
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
|
||||
|
||||
set_source_files_properties(
|
||||
${MACHETE_GEN_SOURCES}
|
||||
PROPERTIES
|
||||
COMPILE_FLAGS
|
||||
"-gencode arch=compute_90a,code=sm_90a")
|
||||
endif()
|
||||
|
||||
# Add pytorch binding for machete (add on even CUDA < 12.0 so that we can
|
||||
# raise an error if the user that this was built with an incompatible
|
||||
# CUDA version)
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
csrc/quantization/machete/machete_pytorch.cu)
|
||||
endif()
|
||||
|
||||
define_gpu_extension_target(
|
||||
@ -200,7 +286,7 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
@ -212,6 +298,11 @@ set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/torch_bindings.cpp"
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/marlin_moe_ops.cu")
|
||||
endif()
|
||||
|
||||
define_gpu_extension_target(
|
||||
_moe_C
|
||||
DESTINATION vllm
|
||||
@ -222,76 +313,7 @@ define_gpu_extension_target(
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
#
|
||||
# _punica_C extension
|
||||
#
|
||||
|
||||
set(VLLM_PUNICA_EXT_SRC
|
||||
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
||||
"csrc/punica/punica_ops.cu"
|
||||
"csrc/punica/torch_bindings.cpp")
|
||||
|
||||
#
|
||||
# Copy GPU compilation flags+update for punica
|
||||
#
|
||||
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
|
||||
"-D__CUDA_NO_HALF_OPERATORS__"
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__"
|
||||
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__")
|
||||
|
||||
#
|
||||
# Filter out CUDA architectures < 8.0 for punica.
|
||||
#
|
||||
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
|
||||
set(VLLM_PUNICA_GPU_ARCHES)
|
||||
foreach(ARCH ${VLLM_GPU_ARCHES})
|
||||
string_to_ver(CODE_VER ${ARCH})
|
||||
if (CODE_VER GREATER_EQUAL 8.0)
|
||||
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
|
||||
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
endif()
|
||||
|
||||
if (VLLM_PUNICA_GPU_ARCHES)
|
||||
define_gpu_extension_target(
|
||||
_punica_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${VLLM_PUNICA_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
else()
|
||||
message(WARNING "Unable to create _punica_C target because none of the "
|
||||
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Add the `default` target which detects which extensions should be
|
||||
# built based on platform/architecture. This is the same logic that
|
||||
# setup.py uses to select which extensions should be built and should
|
||||
# be kept in sync.
|
||||
#
|
||||
# The `default` target makes direct use of cmake easier since knowledge
|
||||
# of which extensions are supported has been factored in, e.g.
|
||||
#
|
||||
# mkdir build && cd build
|
||||
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
|
||||
# cmake --build . --target default
|
||||
#
|
||||
add_custom_target(default)
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling C extension.")
|
||||
@ -300,12 +322,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling moe extension.")
|
||||
add_dependencies(default _moe_C)
|
||||
|
||||
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
|
||||
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
|
||||
# there are supported target arches.
|
||||
if (VLLM_PUNICA_GPU_ARCHES AND
|
||||
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
|
||||
message(STATUS "Enabling punica extension.")
|
||||
add_dependencies(default _punica_C)
|
||||
endif()
|
||||
endif()
|
||||
|
Dockerfile (101 changes)
@ -9,28 +9,23 @@ ARG CUDA_VERSION=12.4.1
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
|
||||
&& python3 --version
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y git curl sudo
|
||||
|
||||
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
||||
RUN python3 -m pip --version
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
@ -42,13 +37,11 @@ WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-adag.txt requirements-adag.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-cuda.txt
|
||||
|
||||
COPY requirements-mamba.txt requirements-mamba.txt
|
||||
RUN python3 -m pip install packaging
|
||||
RUN python3 -m pip install -r requirements-mamba.txt
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
@ -61,23 +54,19 @@ ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
FROM base AS build
|
||||
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-build.txt
|
||||
|
||||
# install compiler cache to speed up compilation leveraging local or remote caching
|
||||
RUN apt-get update -y && apt-get install -y ccache
|
||||
|
||||
# files and directories related to build wheels
|
||||
COPY csrc csrc
|
||||
COPY setup.py setup.py
|
||||
COPY cmake cmake
|
||||
COPY CMakeLists.txt CMakeLists.txt
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-adag.txt requirements-adag.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm vllm
|
||||
@ -88,13 +77,13 @@ ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
# make sure punica kernels are built (for LoRA)
|
||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
ARG buildkite_commit
|
||||
ENV BUILDKITE_COMMIT=${buildkite_commit}
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
@ -103,12 +92,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& if [ "$CUDA_VERSION" = "11.8.0" ]; then \
|
||||
export SCCACHE_BUCKET=vllm-build-sccache-2; \
|
||||
else \
|
||||
export SCCACHE_BUCKET=vllm-build-sccache; \
|
||||
fi \
|
||||
&& export SCCACHE_REGION=us-west-2 \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_IDLE_TIMEOUT=0 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& sccache --show-stats \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
|
||||
@ -122,10 +108,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
|
||||
fi
|
||||
|
||||
# check the size of the wheel, we cannot upload wheels larger than 100MB
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
RUN python3 check-wheel-size.py dist
|
||||
|
||||
# Default max size of the wheel is 250MB
|
||||
ARG VLLM_MAX_SIZE_MB=250
|
||||
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
|
||||
ARG RUN_WHEEL_CHECK=true
|
||||
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
python3 check-wheel-size.py dist; \
|
||||
else \
|
||||
echo "Skipping wheel size check."; \
|
||||
fi
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### DEV IMAGE ####################
|
||||
@ -138,45 +131,30 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-dev.txt
|
||||
|
||||
#################### DEV IMAGE ####################
|
||||
#################### MAMBA Build IMAGE ####################
|
||||
FROM dev as mamba-builder
|
||||
# max jobs used for build
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
|
||||
WORKDIR /usr/src/mamba
|
||||
|
||||
COPY requirements-mamba.txt requirements-mamba.txt
|
||||
|
||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
||||
RUN pip --verbose wheel -r requirements-mamba.txt \
|
||||
--no-build-isolation --no-deps --no-cache-dir
|
||||
|
||||
#################### MAMBA Build IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3.10
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
|
||||
&& python3 --version
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip git vim curl libibverbs-dev
|
||||
|
||||
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
||||
RUN python3 -m pip --version
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
@ -189,12 +167,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install dist/*.whl --verbose
|
||||
|
||||
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
. /etc/environment && \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
|
@ -2,36 +2,49 @@
|
||||
|
||||
FROM ubuntu:22.04 AS cpu-test-1
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update -y \
|
||||
&& apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
|
||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||
|
||||
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
|
||||
# intel-openmp provides additional performance improvement vs. openmp
|
||||
# tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects.
|
||||
RUN pip install intel-openmp
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install intel-openmp
|
||||
|
||||
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
|
||||
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so"
|
||||
|
||||
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
||||
|
||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
RUN pip install --upgrade pip \
|
||||
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
|
||||
ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
|
||||
FROM cpu-test-1 AS build
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
|
||||
COPY ./ ./
|
||||
|
||||
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
|
||||
ARG VLLM_CPU_DISABLE_AVX512
|
||||
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=cache,target=/root/.cache/ccache \
|
||||
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
|
||||
pip install dist/*.whl
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# default base image
|
||||
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
|
||||
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.19.1-ubuntu20.04"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
FROM ubuntu:20.04 AS dev
|
||||
FROM ubuntu:22.04 AS dev
|
||||
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y python3-pip git
|
||||
@ -13,12 +13,15 @@ COPY requirements-common.txt /workspace/vllm/
|
||||
COPY requirements-openvino.txt /workspace/vllm/
|
||||
|
||||
COPY vllm/ /workspace/vllm/vllm
|
||||
COPY csrc/core /workspace/vllm/csrc/core
|
||||
COPY cmake/utils.cmake /workspace/vllm/cmake/
|
||||
COPY CMakeLists.txt /workspace/vllm/
|
||||
COPY setup.py /workspace/vllm/
|
||||
|
||||
# install build requirements
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
|
||||
# build vLLM with OpenVINO backend
|
||||
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
|
||||
|
||||
COPY examples/ /workspace/vllm/examples
|
||||
COPY benchmarks/ /workspace/vllm/benchmarks
|
||||
|
@ -53,10 +53,10 @@ RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(whic
|
||||
# Install torch == 2.5.0 on ROCm
|
||||
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-6.1"*) \
|
||||
python3 -m pip uninstall -y torch torchaudio torchvision \
|
||||
python3 -m pip uninstall -y torch torchvision \
|
||||
&& python3 -m pip install --no-cache-dir --pre \
|
||||
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
|
||||
torchvision==0.20.0.dev20240710 \
|
||||
torch==2.5.0.dev20240726 \
|
||||
torchvision==0.20.0.dev20240726 \
|
||||
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
|
||||
*) ;; esac
|
||||
|
||||
@ -127,19 +127,11 @@ FROM base AS final
|
||||
# Import the vLLM development directory from the build context
|
||||
COPY . .
|
||||
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually remove it so that later steps of numpy upgrade can continue
|
||||
RUN case "$(which python3)" in \
|
||||
*"/opt/conda/envs/py_3.9"*) \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
|
||||
*) ;; esac
|
||||
|
||||
# Package upgrades for useful functionality or to avoid dependency issues
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
|
||||
|
||||
# Make sure punica kernels are built (for LoRA)
|
||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
# Workaround for ray >= 2.10.0
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
# Silences the HF Tokenizers warning
|
||||
|
@ -1,20 +1,17 @@
|
||||
ARG NIGHTLY_DATE="20240713"
|
||||
ARG NIGHTLY_DATE="20240828"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install aiohttp separately to avoid build errors.
|
||||
RUN pip install aiohttp
|
||||
# Install NumPy 1 instead of NumPy 2.
|
||||
RUN pip install "numpy<2"
|
||||
# Install the TPU and Pallas dependencies.
|
||||
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
|
||||
# Build vLLM.
|
||||
COPY . /workspace/vllm
|
||||
ENV VLLM_TARGET_DEVICE="tpu"
|
||||
RUN cd /workspace/vllm && python setup.py develop
|
||||
RUN cd /workspace/vllm && python3 -m pip install -r requirements-tpu.txt
|
||||
RUN cd /workspace/vllm && python3 setup.py develop
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
@ -1,4 +1,5 @@
|
||||
include LICENSE
|
||||
include requirements-adag.txt
|
||||
include requirements-common.txt
|
||||
include requirements-cuda.txt
|
||||
include requirements-rocm.txt
|
||||
|
README.md (31 changes)
@@ -10,22 +10,23 @@ Easy, fast, and cheap LLM serving for everyone
</h3>

<p align="center">
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> |

</p>

---

**The Fifth vLLM Bay Area Meetup (July 24th 5pm-8pm PT)**
**vLLM & NVIDIA Triton User Meetup (Monday, September 9, 5pm-9pm PT) at Fort Mason, San Francisco**

We are excited to announce our fifth vLLM Meetup!
Join us to hear about vLLM's recent updates and the upcoming roadmap.
Additionally, our collaborators from AWS will be presenting their insights and experiences in deploying vLLM.
Register now [here](https://lu.ma/lp0gyjqr) and be part of the event!
We are excited to announce our sixth vLLM Meetup, in collaboration with the NVIDIA Triton team.
Join us to hear about vLLM's recent updates on performance.
Register now [here](https://lu.ma/87q3nvnh) and be part of the event!

---

*Latest News* 🔥
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
@@ -44,10 +45,12 @@ vLLM is fast with:
- Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
- Optimized CUDA kernels
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer.
- Speculative decoding
- Chunked prefill

**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/3924) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).

vLLM is flexible and easy to use with:

@@ -56,20 +59,21 @@ vLLM is flexible and easy to use with:
- Tensor parallelism and pipeline parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs
- (Experimental) Prefix caching support
- (Experimental) Multi-lora support
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron.
- Prefix caching support
- Multi-lora support

vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral)
- Embedding Models (e.g. E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)

Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).

## Getting Started

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):

```bash
pip install vllm
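# --- Hedged sketch (editorial aside, not from the README diff): one way to
# --- smoke-test the install by starting the OpenAI-compatible server and
# --- sending a completion request. The model name and port are illustrative
# --- assumptions.
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
# wait for the server to come up, then:
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "facebook/opt-125m", "prompt": "Hello, my name is", "max_tokens": 16}'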

@@ -107,6 +111,7 @@ vLLM is a community project. Our compute resources for development and testing a
- Roblox
- RunPod
- Sequoia Capital
- Skywork AI
- Trainy
- UC Berkeley
- UC San Diego
@@ -225,8 +225,8 @@ async def async_request_openai_completions(
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(
"completions"
), "OpenAI Completions API URL must end with 'completions'."
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."

async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
@@ -276,8 +276,9 @@ async def async_request_openai_completions(
output.ttft = ttft

# Decoding phase
output.itl.append(timestamp -
most_recent_timestamp)
else:
output.itl.append(timestamp -
most_recent_timestamp)

most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
@@ -1,8 +1,45 @@
"""
Benchmark the efficiency of prefix caching.

This script allows you to benchmark the performance of
a model with and without prefix caching using either fixed prompts
or prompts sampled from the ShareGPT dataset.

Fixed example usage:
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--enable-prefix-caching \
--num-prompts 1 \
--repeat-count 100

ShareGPT example usage:
# This command samples 20 prompts with input lengths
# between 128 and 256 tokens from the ShareGPT dataset,
# then replicates each prompt 5 times.
python benchmark_prefix_caching.py \
--model meta-llama/Llama-2-7b-chat-hf \
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
--enable-prefix-caching \
--num-prompts 20 \
--repeat-count 5 \
--input-length-range 128:256
"""

import json
import random
import time
from typing import List, Optional, Tuple

from transformers import PreTrainedTokenizerBase

from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser

try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer

PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
||||
|
||||
|
||||
@ -15,7 +52,83 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
|
||||
print(f"cost time {end_time - start_time}")
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_length_range: Tuple[int, int],
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Shuffle the dataset.
|
||||
random.shuffle(dataset)
|
||||
|
||||
min_len, max_len = input_length_range
|
||||
|
||||
# Filter out sequences that are too long or too short
|
||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
||||
for i in range(len(dataset)):
|
||||
if len(filtered_dataset) == num_requests:
|
||||
break
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = len(completion_token_ids
|
||||
) if fixed_output_len is None else fixed_output_len
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if min_len <= prompt_len <= max_len:
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def repeat_and_sort_requests(requests: List[Tuple[str, int, int]],
|
||||
repeat_count: int,
|
||||
sort: bool = False) -> List[str]:
|
||||
repeated_requests = requests * repeat_count
|
||||
if sort:
|
||||
repeated_requests.sort(key=lambda x: x[1])
|
||||
else:
|
||||
random.shuffle(repeated_requests)
|
||||
return [req[0] for req in repeated_requests]
|
||||
|
||||
|
||||
def main(args):
|
||||
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
|
||||
input_length_range = tuple(map(int, args.input_length_range.split(':')))
|
||||
|
||||
if args.dataset_path is not None:
|
||||
print(f"Start to sample {args.num_prompts} prompts"
|
||||
"from {args.dataset_path}")
|
||||
filtered_datasets = sample_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
input_length_range=input_length_range,
|
||||
fixed_output_len=args.output_len,
|
||||
)
|
||||
else:
|
||||
prompt_len = len(tokenizer(PROMPT).input_ids)
|
||||
filtered_datasets = [(PROMPT, prompt_len, args.output_len)
|
||||
] * args.num_prompts
|
||||
|
||||
llm = LLM(model=args.model,
|
||||
tokenizer_mode='auto',
|
||||
trust_remote_code=True,
|
||||
@ -24,10 +137,13 @@ def main(args):
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
enable_prefix_caching=args.enable_prefix_caching)
|
||||
|
||||
num_prompts = 100
|
||||
prompts = [PROMPT] * num_prompts
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
|
||||
print("Testing filtered datasets")
|
||||
prompts = repeat_and_sort_requests(filtered_datasets,
|
||||
repeat_count=args.repeat_count,
|
||||
sort=args.sort)
|
||||
|
||||
print("------warm up------")
|
||||
test_prefix(
|
||||
llm=llm,
|
||||
@ -45,11 +161,15 @@ def main(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the performance with or without automatic '
|
||||
'prefix caching.')
|
||||
description=
|
||||
'Benchmark the performance with or without automatic prefix caching.')
|
||||
parser.add_argument('--model',
|
||||
type=str,
|
||||
default='baichuan-inc/Baichuan2-13B-Chat')
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
parser.add_argument('--enable-prefix-caching',
|
||||
@ -58,5 +178,21 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--use-v2-block-manager',
|
||||
action='store_true',
|
||||
help='Use BlockSpaceMangerV2')
|
||||
parser.add_argument('--num-prompts',
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of the prompts sampled from dataset")
|
||||
parser.add_argument('--repeat-count',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Number of times to repeat each prompt')
|
||||
parser.add_argument('--sort',
|
||||
action='store_true',
|
||||
help='Sort prompts by input length')
|
||||
parser.add_argument('--input-length-range',
|
||||
type=str,
|
||||
default='128:256',
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -56,20 +56,27 @@ class BenchmarkMetrics:
|
||||
total_input: int
|
||||
total_output: int
|
||||
request_throughput: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
total_token_throughput: float
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
percentiles_ttft_ms: List[Tuple[float, float]]
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
std_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
percentiles_tpot_ms: List[Tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
p99_itl_ms: float
|
||||
percentiles_itl_ms: List[Tuple[float, float]]
|
||||
# E2EL stands for end-to-end latency per request.
|
||||
# It is the time taken on the client side from sending
|
||||
# a request to receiving a complete response.
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: List[Tuple[float, float]]
|
||||
|
||||
|
||||
def sample_sharegpt_requests(
|
||||
@ -235,6 +242,8 @@ def calculate_metrics(
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
total_input = 0
|
||||
@ -242,6 +251,7 @@ def calculate_metrics(
|
||||
itls: List[float] = []
|
||||
tpots: List[float] = []
|
||||
ttfts: List[float] = []
|
||||
e2els: List[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
@ -258,6 +268,7 @@ def calculate_metrics(
|
||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||
itls += outputs[i].itl
|
||||
ttfts.append(outputs[i].ttft)
|
||||
e2els.append(outputs[i].latency)
|
||||
completed += 1
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
@ -272,21 +283,29 @@ def calculate_metrics(
|
||||
total_input=total_input,
|
||||
total_output=sum(actual_output_lens),
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or 0) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_tpot_ms=np.mean(tpots or 0) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
std_tpot_ms=np.std(tpots or 0) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
|
||||
median_tpot_ms=np.median(tpots or 0) * 1000,
|
||||
percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
mean_e2el_ms=np.mean(e2els or 0) * 1000,
std_e2el_ms=np.std(e2els or 0) * 1000,
median_e2el_ms=np.median(e2els or 0) * 1000,
|
||||
percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000)
|
||||
for p in selected_percentiles],
|
||||
)
|
||||
|
||||
return metrics, actual_output_lens
|
||||
@ -295,6 +314,7 @@ def calculate_metrics(
|
||||
async def benchmark(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
base_url: str,
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
@ -302,6 +322,9 @@ async def benchmark(
|
||||
use_beam_search: bool,
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
profile: bool,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -326,6 +349,22 @@ async def benchmark(
|
||||
f"are correctly specified. Error: {test_output.error}")
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
profile_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
api_url=base_url + "/start_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
@ -349,6 +388,21 @@ async def benchmark(
|
||||
pbar=pbar)))
|
||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
print("Stopping profiler...")
|
||||
profile_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=test_prompt,
|
||||
api_url=base_url + "/stop_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
output_len=test_output_len,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
if profile_output.success:
|
||||
print("Profiler stopped")
|
||||
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
|
||||
@ -359,6 +413,8 @@ async def benchmark(
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -370,27 +426,10 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
||||
metrics.input_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
||||
metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
||||
n=50,
|
||||
c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
||||
metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||
print("=" * 50)
|
||||
print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):",
|
||||
metrics.total_token_throughput))
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
@ -398,20 +437,8 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
"median_ttft_ms": metrics.median_ttft_ms,
|
||||
"std_ttft_ms": metrics.std_ttft_ms,
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"std_tpot_ms": metrics.std_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||
"mean_itl_ms": metrics.mean_itl_ms,
|
||||
"median_itl_ms": metrics.median_itl_ms,
|
||||
"std_itl_ms": metrics.std_itl_ms,
|
||||
"p99_itl_ms": metrics.p99_itl_ms,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
@ -419,6 +446,47 @@ async def benchmark(
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
# E.g., "TTFT"
|
||||
metric_name: str,
|
||||
# E.g., "Time to First Token"
|
||||
metric_header: str,
|
||||
):
|
||||
# This function prints statistics for the specified metric
# and adds them to the result dict.
|
||||
if metric_attribute_name not in selected_percentile_metrics:
|
||||
return
|
||||
print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_attribute_name}_ms")))
|
||||
print("{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_attribute_name}_ms")))
|
||||
result[f"mean_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"mean_{metric_attribute_name}_ms")
|
||||
result[f"median_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"median_{metric_attribute_name}_ms")
|
||||
result[f"std_{metric_attribute_name}_ms"] = getattr(
|
||||
metrics, f"std_{metric_attribute_name}_ms")
|
||||
for p, value in getattr(metrics,
|
||||
f"percentiles_{metric_attribute_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
|
||||
value))
|
||||
result[f"p{p_word}_{metric_attribute_name}_ms"] = value
|
||||
|
||||
process_one_metric("ttft", "TTFT", "Time to First Token")
|
||||
process_one_metric("tpot", "TPOT",
|
||||
"Time per Output Token (excl. 1st token)")
|
||||
process_one_metric("itl", "ITL", "Inter-token Latency")
|
||||
process_one_metric("e2el", "E2EL", "End-to-end Latency")
|
||||
|
||||
print("=" * 50)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@ -433,8 +501,10 @@ def main(args: argparse.Namespace):
|
||||
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
base_url = f"{args.base_url}"
|
||||
else:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
@ -506,6 +576,7 @@ def main(args: argparse.Namespace):
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
@ -513,6 +584,11 @@ def main(args: argparse.Namespace):
|
||||
use_beam_search=args.use_beam_search,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
profile=args.profile,
|
||||
selected_percentile_metrics=args.percentile_metrics.split(","),
|
||||
selected_percentiles=[
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
@ -693,6 +769,12 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Specify to disable tqdm progress bar.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
help="Use Torch Profiler. The endpoint must be launched with "
|
||||
"VLLM_TORCH_PROFILER_DIR to enable profiler.",
|
||||
)
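A hedged example of the profiling workflow implied by this flag; the script path, model name, and trace directory are illustrative assumptions, and the usual dataset arguments are omitted:

```bash
# Terminal 1: serve with profiling enabled (traces are written to the given directory).
VLLM_TORCH_PROFILER_DIR=/tmp/vllm_profile \
    python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf

# Terminal 2: run the serving benchmark; --profile wraps the run in start/stop profile calls.
python3 benchmarks/benchmark_serving.py --backend vllm \
    --model meta-llama/Llama-2-7b-chat-hf --profile
```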
|
||||
parser.add_argument(
|
||||
"--save-result",
|
||||
action="store_true",
|
||||
@ -722,6 +804,23 @@ if __name__ == "__main__":
|
||||
"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||
" format.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percentile-metrics",
|
||||
type=str,
|
||||
default="ttft,tpot,itl",
|
||||
help="Comma-seperated list of selected metrics to report percentils. "
|
||||
"This argument specifies the metrics to report percentiles. "
|
||||
"Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". "
|
||||
"Default value is \"ttft,tpot,itl\".")
|
||||
parser.add_argument(
|
||||
"--metric-percentiles",
|
||||
type=str,
|
||||
default="99",
|
||||
help="Comma-seperated list of percentiles for selected metrics. "
|
||||
"To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". "
|
||||
"Default value is \"99\". "
|
||||
"Use \"--percentile-metrics\" to select metrics.",
|
||||
)
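For instance, to report the 25th, 50th, and 99th percentiles of TTFT, ITL, and end-to-end latency, the two flags above might be combined as follows (script path and other arguments are assumptions for illustration):

```bash
# Report TTFT, ITL, and end-to-end latency at the 25th/50th/99th percentiles.
# Model and dataset arguments are omitted; add them as usual for this script.
python3 benchmarks/benchmark_serving.py --backend vllm \
    --percentile-metrics ttft,itl,e2el \
    --metric-percentiles 25,50,99
```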
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -6,13 +6,16 @@ import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import uvloop
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -82,8 +85,11 @@ def run_vllm(
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
@ -106,6 +112,9 @@ def run_vllm(
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -129,6 +138,93 @@ def run_vllm(
|
||||
return end - start
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
quantization_param_path: Optional[str],
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_batched_tokens: int,
|
||||
distributed_executor_backend: Optional[str],
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
num_scheduler_steps: int = 1,
|
||||
use_v2_block_manager: bool = False,
|
||||
download_dir: Optional[str] = None,
|
||||
load_format: str = EngineArgs.load_format,
|
||||
disable_async_output_proc: bool = False,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
) -> float:
|
||||
from vllm import SamplingParams
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
load_format=load_format,
|
||||
num_scheduler_steps=num_scheduler_steps,
|
||||
use_v2_block_manager=use_v2_block_manager,
|
||||
disable_async_output_proc=disable_async_output_proc,
|
||||
worker_use_ray=False,
|
||||
engine_use_ray=False,
|
||||
disable_log_requests=True,
|
||||
)
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: List[str] = []
|
||||
sampling_params: List[SamplingParams] = []
|
||||
for prompt, _, output_len in requests:
|
||||
prompts.append(prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=0.0 if use_beam_search else 1.0,
|
||||
top_p=1.0,
|
||||
use_beam_search=use_beam_search,
|
||||
ignore_eos=True,
|
||||
max_tokens=output_len,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_hf(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
@ -224,7 +320,7 @@ def main(args: argparse.Namespace):
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
run_args = [
|
||||
requests, args.model, args.tokenizer, args.quantization,
|
||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||
@ -232,7 +328,16 @@ def main(args: argparse.Namespace):
|
||||
args.quantization_param_path, args.device,
|
||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||
args.max_num_batched_tokens, args.distributed_executor_backend,
|
||||
args.gpu_memory_utilization, args.download_dir, args.load_format)
|
||||
args.gpu_memory_utilization, args.num_scheduler_steps,
|
||||
args.use_v2_block_manager, args.download_dir, args.load_format,
|
||||
args.disable_async_output_proc
|
||||
]
|
||||
|
||||
if args.async_engine:
|
||||
run_args.append(args.disable_frontend_multiprocessing)
|
||||
elapsed_time = uvloop.run(run_vllm_async(*run_args))
|
||||
else:
|
||||
elapsed_time = run_vllm(*run_args)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -353,10 +458,18 @@ if __name__ == "__main__":
|
||||
choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"],
|
||||
help='device type for vLLM execution, supporting CUDA, OpenVINO, '
'TPU, XPU, and CPU.')
|
||||
parser.add_argument(
|
||||
"--num-scheduler-steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Maximum number of forward steps per scheduler call.")
|
||||
parser.add_argument("--use-v2-block-manager",
|
||||
action='store_true',
|
||||
help="Enable block manager v2.")
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="enable automatic prefix caching for vLLM backend.")
|
||||
help="Enable automatic prefix caching for vLLM backend.")
|
||||
parser.add_argument("--enable-chunked-prefill",
|
||||
action='store_true',
|
||||
help="enable chunked prefill for vLLM backend.")
|
||||
@ -405,6 +518,19 @@ if __name__ == "__main__":
|
||||
'section for more information.\n'
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
'quantization.\n')
|
||||
parser.add_argument(
|
||||
"--disable-async-output-proc",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable async output processor for vLLM backend.")
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
@ -13,7 +13,7 @@ from weight_shapes import WEIGHT_SHAPES
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
|
||||
@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
|
||||
def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.mm(a, b)
|
||||
|
||||
|
||||
def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
use_fast_accum=True)
|
||||
|
||||
|
||||
def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
|
||||
sub_label: str, fn: Callable, description: str) -> TMeasurement:
|
||||
|
||||
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
|
||||
**kwargs) -> TMeasurement:
|
||||
min_run_time = 1
|
||||
|
||||
globals = {
|
||||
"a": a,
|
||||
"b": b,
|
||||
"scale_a": scale_a,
|
||||
"scale_b": scale_b,
|
||||
"out_dtype": out_dtype,
|
||||
"args": args,
|
||||
"kwargs": kwargs,
|
||||
"fn": fn,
|
||||
}
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
|
||||
stmt="fn(*args, **kwargs)",
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -110,19 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
|
||||
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
|
||||
|
||||
timers = []
|
||||
# pytorch impl
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16),
|
||||
b.to(dtype=torch.bfloat16)))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(label, sub_label,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
|
||||
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
|
||||
# cutlass with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass with azp per-tensor
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj))
|
||||
|
||||
# cutlass with azp per-tensor + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, None, bias))
|
||||
|
||||
# cutlass with azp per-token
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp))
|
||||
|
||||
# cutlass with azp per-token + bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
|
||||
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
|
||||
torch.bfloat16, azp_adj, azp, bias))
|
||||
|
||||
return timers
|
||||
|
||||
@ -133,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
timers = []
|
||||
|
||||
# pytorch impl w. bf16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda")))
|
||||
|
||||
# pytorch impl: bf16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16))
|
||||
|
||||
# pytorch impl: bf16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
pytorch_fp8_impl_fast_accum,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# pytorch impl: fp16 output, without fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16))
|
||||
|
||||
# pytorch impl: fp16 output, with fp8 fast accum
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
pytorch_fp8_impl_fast_accum,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
|
||||
bench_fn(label,
|
||||
sub_label,
|
||||
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||
torch._scaled_mm,
|
||||
a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=torch.float16,
|
||||
use_fast_accum=True))
|
||||
|
||||
# cutlass impl: bf16 output
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
|
||||
torch.bfloat16))
|
||||
# cutlass impl: fp16 output
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
|
||||
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
|
||||
|
||||
# cutlass impl: bf16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
|
||||
bias))
|
||||
|
||||
# cutlass impl: fp16 output, with bias
|
||||
timers.append(
|
||||
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
|
||||
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
|
||||
bias.to(dtype=torch.float16)))
|
||||
|
||||
return timers
|
||||
|
||||
|
||||
@ -193,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
|
||||
|
||||
def run(dtype: torch.dtype,
|
||||
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||
|
||||
results = []
|
||||
for m, k, n in MKNs:
|
||||
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
|
||||
@ -209,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
|
||||
MKNs: Iterable[Tuple[int, int, int]],
|
||||
base_description: str,
|
||||
timestamp=None):
|
||||
|
||||
print(f"== All Results {base_description} ====")
|
||||
print_timers(data)
|
||||
|
||||
@ -244,7 +281,6 @@ def run_range_bench(args):
|
||||
|
||||
|
||||
def run_model_bench(args):
|
||||
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
benchmarks/kernels/benchmark_layernorm.py (new file, 89 lines)
@@ -0,0 +1,89 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(num_tokens: int,
|
||||
hidden_size: int,
|
||||
add_residual: bool,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
do_profile: bool = False,
|
||||
num_warmup_iters: int = 5,
|
||||
num_iters: int = 100) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
layer = RMSNorm(hidden_size).to(dtype=dtype)
|
||||
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||
scale = 1 / (2 * hidden_size)
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
x *= scale
|
||||
residual = torch.randn_like(x) * scale if add_residual else None
|
||||
|
||||
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
torch.cuda.synchronize()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(num_iters):
|
||||
layer(x, residual)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
torch.cuda.cudart().cudaProfilerStop()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
print("Warming up...")
|
||||
run_benchmark = run_cuda_benchmark
|
||||
run_benchmark(num_iters=num_warmup_iters, profile=False)
|
||||
|
||||
# Benchmark.
|
||||
if do_profile:
|
||||
latency = run_benchmark(num_iters=1, profile=True)
|
||||
else:
|
||||
latency = run_benchmark(num_iters=num_iters, profile=False)
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the layernorm kernel.")
|
||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||
parser.add_argument("--add-residual", action="store_true")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||
parser.add_argument("--num-iters",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations. "
|
||||
"If --profile is set, this number is ignored")
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
main(num_tokens=args.num_tokens,
|
||||
hidden_size=args.hidden_size,
|
||||
add_residual=args.add_residual,
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
num_warmup_iters=args.num_warmup_iters,
|
||||
num_iters=args.num_iters)
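A typical invocation of this new kernel benchmark, using only the flags defined above (the path assumes the script is run from the repository root):

```bash
python3 benchmarks/kernels/benchmark_layernorm.py \
    --num-tokens 4096 --hidden-size 8192 \
    --add-residual --dtype half --num-iters 100
```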
|
benchmarks/kernels/benchmark_machete.py (new file, 372 lines)
@@ -0,0 +1,372 @@
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
import math
|
||||
import pickle as pkl
|
||||
import time
|
||||
from typing import Callable, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, pack_rows, quantize_weights)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
|
||||
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # make col major
|
||||
return ops.machete_prepack_B(w_q, wtype)
|
||||
|
||||
|
||||
def make_bench_tensors(
|
||||
atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
|
||||
k: int
|
||||
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
|
||||
torch.tensor]]]:
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
# To make sure the weights do not fit in L2 cache between runs, construct
# enough weight matrices to exceed L2 (about 50 MB on an H100), i.e. target a
# total weight size of > 2 * 50 MB.
|
||||
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=atype) * 5
|
||||
weights = [
|
||||
torch.randn((k, n), device="cuda", dtype=atype)
|
||||
for _ in range(num_weights)
|
||||
]
|
||||
quantized_weights = [
|
||||
quantize_weights(w, wtype, group_size) for w in weights
|
||||
]
|
||||
|
||||
return a, quantized_weights
|
||||
|
||||
|
||||
# impl
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(label: str, sub_label: str, description: str,
|
||||
fn: Callable) -> TMeasurement:
|
||||
|
||||
min_run_time = 1
|
||||
return TBenchmark.Timer(
|
||||
stmt="fn()",
|
||||
globals={
|
||||
"fn": fn
|
||||
},
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description=description,
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
|
||||
|
||||
def loop_over_weights(
|
||||
a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
|
||||
torch.tensor, torch.tensor]],
|
||||
fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
|
||||
None]):
|
||||
for w_ref, w_q, w_s, _ in weights:
|
||||
fn(a, w_ref, w_q, w_s)
|
||||
|
||||
|
||||
def bench(atype: torch.dtype,
|
||||
wtype: ScalarType,
|
||||
group_size: int,
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
label: str,
|
||||
sub_label: str,
|
||||
benchmark_marlinv1: bool = True,
|
||||
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
|
||||
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
|
||||
sub_label += f", L={len(weights)}"
|
||||
|
||||
weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
|
||||
for w_ref, w_q, w_s, w_zp in weights]
|
||||
|
||||
timers = []
|
||||
# pytorch impl
|
||||
timers.append(
|
||||
bench_fn(
|
||||
label, sub_label, "torch.matmul", lambda: loop_over_weights(
|
||||
a,
|
||||
weights,
|
||||
lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
|
||||
)))
|
||||
|
||||
if benchmark_marlinv1:
|
||||
w_ref = weights[0][0]
|
||||
|
||||
w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
|
||||
|
||||
def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
|
||||
w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
|
||||
        return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
                                      wtype.size_bits)

    def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
        return marlin_permute_scales(w_s, *w_ref.shape, group_size)

    weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
                         marlinv1_permute_scales(w_s), w_zp)
                        for w_ref, w_q, w_s, w_zp in weights]

    workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)

    # marlinv1
    timers.append(
        bench_fn(
            label, sub_label, "marlin_orig", lambda: loop_over_weights(
                a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
                gptq_marlin_gemm(a,
                                 w_q,
                                 w_s,
                                 w_zp_empty,
                                 g_idx,
                                 sort_indices,
                                 workspace.scratch,
                                 wtype,
                                 size_m=a.shape[0],
                                 size_n=w_ref.shape[1],
                                 size_k=w_ref.shape[0],
                                 is_k_full=True))))

    # machete
    timers.append(
        bench_fn(
            label, sub_label, "machete_heuristic", lambda: loop_over_weights(
                a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
                    a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))

    if sweep_schedules:
        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(wtype)
        for schedule in reversed(schedules):

            def run(a, _, w_q, w_s, schedule=schedule):
                ops.machete_gemm(a,
                                 w_q,
                                 wtype,
                                 w_s,
                                 b_group_size=group_size,
                                 schedule=schedule)

            res = bench_fn(label, sub_label, "machete_best",
                           lambda: loop_over_weights(a, weights_machete, run))

            print(f" {res.median:5.5} ", schedule)
            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        timers.append(best)

    return timers


# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype, sweep_schedules: bool,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:

    results = []
    for m, k, n in MKNs:
        timers = bench(dtype,
                       scalar_types.uint4b8,
                       128,
                       m,
                       k,
                       n,
                       f"{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})",
                       sweep_schedules=sweep_schedules)
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[Tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, args.sweep_schedules, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "bfloat16":
            return torch.bfloat16
        if dt == "float16":
            return torch.float16
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
        """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['bfloat16', 'float16']",
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)
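The .pkl written above is just a list of torch.utils.benchmark Measurement objects, so the results can be reloaded and re-compared offline (the graph_machete_bench.py helper added later in this changeset reads the same file to plot them). A minimal sketch, assuming a placeholder file name:

import pickle

from torch.utils.benchmark import Compare

# Placeholder name: use the file actually written by make_output() /
# run_model_bench() above.
with open("model_bench-torch.bfloat16-1700000000.pkl", "rb") as f:
    timers = pickle.load(f)  # list of torch.utils.benchmark Measurement objects

Compare(timers).print()  # same table layout that print_timers() produces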
@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
-    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
+    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+    MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     MarlinWorkspace, marlin_quantize)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
     marlin_24_quantize)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    gptq_pack, quantize_weights, sort_weights)
+    gptq_pack, gptq_quantize_weights, sort_weights)
+from vllm.scalar_type import ScalarType
 from vllm.utils import FlexibleArgumentParser

 DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True]


 def bench_run(results: List[benchmark.Measurement], model: str,
-              act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
-              size_m: int, size_k: int, size_n: int):
+              act_order: bool, is_k_full: bool, quant_type: ScalarType,
+              group_size: int, size_m: int, size_k: int, size_n: int):
     label = "Quant Matmul"

-    sub_label = ("{}, act={} k_full={}, b={}, g={}, "
-                 "MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
-                                         group_size, size_m, size_k, size_n))
+    sub_label = ("{}, act={} k_full={}, q={}, g={}, "
+                 "MKN=({}x{}x{})".format(model, act_order, is_k_full,
+                                         str(quant_type), group_size, size_m,
+                                         size_k, size_n))

     print(f"Testing: {sub_label}")

@ -50,16 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         marlin_g_idx,
         marlin_sort_indices,
         marlin_rand_perm,
-    ) = marlin_quantize(b, num_bits, group_size, act_order)
+    ) = marlin_quantize(b, quant_type, group_size, act_order)

     # Marlin_24 quant
     (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
-     marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
+     marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)

+    marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
+
     # GPTQ quant
     (w_ref, q_w, s, g_idx,
-     rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
-    q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
+     rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
+    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

     # For act_order, sort the "weights" and "g_idx"
     # so that group ids are increasing
@ -73,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str,

     marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
                                           GPTQ_MARLIN_24_MAX_PARALLEL)
+    marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)

     globals = {
         # Gen params
-        "num_bits": num_bits,
+        "quant_type": quant_type,
         "group_size": group_size,
         "size_m": size_m,
         "size_n": size_n,
@ -87,6 +92,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
         "marlin_w_ref": marlin_w_ref,
         "marlin_q_w": marlin_q_w,
         "marlin_s": marlin_s,
+        "marlin_zp": marlin_zp,
         "marlin_g_idx": marlin_g_idx,
         "marlin_sort_indices": marlin_sort_indices,
         "marlin_rand_perm": marlin_rand_perm,
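The timing hunks below (and the rest of this file) all follow the same torch.utils.benchmark pattern: a Timer whose stmt string is evaluated against the globals dict assembled above, with blocked_autorange() choosing the repetition count. A minimal, self-contained sketch of that pattern, with a plain torch.matmul standing in for the quantized kernels being measured:

import torch
import torch.utils.benchmark as benchmark

a = torch.randn(16, 4096)
b = torch.randn(4096, 4096)

res = benchmark.Timer(
    stmt="output = torch.matmul(a, b)",  # evaluated against the names in `globals`
    globals={"a": a, "b": b, "torch": torch},
    label="Quant Matmul",
    sub_label="example",
    description="torch_matmul",
).blocked_autorange(min_run_time=1)

print(res.median)  # seconds per iteration; Measurement objects feed benchmark.Compare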
@ -125,19 +131,29 @@ def bench_run(results: List[benchmark.Measurement], model: str,
     results.append(
         benchmark.Timer(
             stmt=
-            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)",  # noqa: E501
+            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
-            description="gptq_marlin_gemm",
+            description="gptq_marlin_gemm_fp16",
         ).blocked_autorange(min_run_time=min_run_time))

-    if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
+    results.append(
+        benchmark.Timer(
+            stmt=
+            "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)",  # noqa: E501
+            globals=globals,
+            label=label,
+            sub_label=sub_label,
+            description="gptq_marlin_gemm_fp32",
+        ).blocked_autorange(min_run_time=min_run_time))
+
+    if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
             and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
         results.append(
             benchmark.Timer(
                 stmt=
-                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)",  # noqa: E501
+                "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501
                 globals=globals,
                 label=label,
                 sub_label=sub_label,
@ -147,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
     results.append(
         benchmark.Timer(
             stmt=
-            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)",  # noqa: E501
+            "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501
             globals=globals,
             label=label,
             sub_label=sub_label,
@ -183,12 +199,13 @@ def main(args):
                 ) > 0 and is_k_full not in args.limit_k_full:
                     continue

-            for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
-                if len(args.limit_num_bits
-                       ) > 0 and num_bits not in args.limit_num_bits:
+            for quant_type in query_marlin_supported_quant_types(
+                    False):
+                if len(args.limit_num_bits) > 0 and \
+                        quant_type.size_bits not in args.limit_num_bits:
                     continue

-                for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
+                for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
                     if len(
                             args.limit_group_size
                     ) > 0 and group_size not in args.limit_group_size:
@ -202,8 +219,8 @@ def main(args):

                     for size_m in args.batch_sizes:
                         bench_run(results, model, act_order, is_k_full,
-                                  num_bits, group_size, size_m, size_k,
-                                  size_n)
+                                  quant_type, group_size, size_m,
+                                  size_k, size_n)

     compare = benchmark.Compare(results)
     compare.print()
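The recurring change in this benchmark is that the kernels now take a vllm.scalar_type.ScalarType instead of a bare num_bits integer, so the weight format (bit width, signedness, bias) travels as a single value. A minimal sketch of the new style, using the uint4b8 type that the machete benchmark above passes to bench():

from vllm.scalar_type import scalar_types

quant_type = scalar_types.uint4b8  # 4-bit GPTQ-style weight type used in these benchmarks
print(quant_type.size_bits)        # 4 -- the information the old num_bits argument carried
print(str(quant_type))             # readable name, now embedded in the benchmark sub_label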
@ -30,19 +30,36 @@ def benchmark_config(
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
num_iters: int = 100,
|
||||
) -> float:
|
||||
init_dtype = torch.float16 if use_fp8 else dtype
|
||||
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
w1 = torch.randn(num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
dtype=init_dtype)
|
||||
w2 = torch.randn(num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
dtype=init_dtype)
|
||||
if use_int8_w8a16:
|
||||
w1 = torch.randint(-127,
|
||||
127, (
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
),
|
||||
dtype=torch.int8)
|
||||
w2 = torch.randint(-127,
|
||||
127, (
|
||||
num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
),
|
||||
dtype=torch.int8)
|
||||
else:
|
||||
w1 = torch.randn(num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
dtype=init_dtype)
|
||||
w2 = torch.randn(num_experts,
|
||||
hidden_size,
|
||||
shard_intermediate_size // 2,
|
||||
dtype=init_dtype)
|
||||
gating_output = torch.randn(num_iters,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
@ -52,7 +69,11 @@ def benchmark_config(
|
||||
w2_scale = None
|
||||
a1_scale = None
|
||||
a2_scale = None
|
||||
if use_fp8:
|
||||
if use_int8_w8a16:
|
||||
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_fp8_w8a8:
|
||||
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
@ -76,7 +97,8 @@ def benchmark_config(
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
override_config=config,
|
||||
use_fp8=use_fp8,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
@ -155,11 +177,13 @@ class BenchmarkWorker:
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
torch.cuda.manual_seed_all(self.seed)
|
||||
|
||||
dtype_str = "float8" if use_fp8 else None
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8)
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
|
||||
@ -173,7 +197,8 @@ class BenchmarkWorker:
|
||||
key=lambda x: abs(x - num_tokens))]
|
||||
kernel_time = benchmark_config(config, num_tokens, num_experts,
|
||||
shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8)
|
||||
topk, dtype, use_fp8_w8a8,
|
||||
use_int8_w8a16)
|
||||
return config, kernel_time
|
||||
|
||||
def tune(
|
||||
@ -184,9 +209,10 @@ class BenchmarkWorker:
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
search_space: List[BenchmarkConfig],
|
||||
) -> BenchmarkConfig:
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
@ -198,7 +224,8 @@ class BenchmarkWorker:
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=10)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
}
|
||||
|
||||
|
||||
def save_configs(
|
||||
configs: Dict[int, BenchmarkConfig],
|
||||
num_experts: int,
|
||||
shard_intermediate_size: int,
|
||||
hidden_size: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
) -> None:
|
||||
dtype_str = "float8" if use_fp8 else None
|
||||
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||
dtype: torch.dtype, use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool) -> None:
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_fp8_w8a8=use_fp8_w8a8)
|
||||
|
||||
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||
# is the intermediate size after silu_and_mul.
|
||||
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
|
||||
dtype_str)
|
||||
|
||||
print(f"Writing best config to {filename}...")
|
||||
with open(filename, "w") as f:
|
||||
json.dump(configs, f, indent=4)
|
||||
@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.ffn_config.moe_top_k
|
||||
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "JambaForCausalLM":
|
||||
E = config.num_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
@ -262,7 +293,8 @@ def main(args: argparse.Namespace):
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = config.torch_dtype
|
||||
use_fp8 = args.dtype == "fp8"
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
|
||||
if args.batch_size is None:
|
||||
batch_sizes = [
|
||||
@ -294,21 +326,21 @@ def main(args: argparse.Namespace):
|
||||
start = time.time()
|
||||
configs = _distribute(
|
||||
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8, search_space)
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
|
||||
for batch_size in batch_sizes])
|
||||
best_configs = {
|
||||
M: sort_config(config)
|
||||
for M, config in zip(batch_sizes, configs)
|
||||
}
|
||||
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8)
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
|
||||
end = time.time()
|
||||
print(f"Tuning took {end - start:.2f} seconds")
|
||||
else:
|
||||
outputs = _distribute("benchmark",
|
||||
[(batch_size, E, shard_intermediate_size,
|
||||
hidden_size, topk, dtype, use_fp8)
|
||||
for batch_size in batch_sizes])
|
||||
outputs = _distribute(
|
||||
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
|
||||
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
|
||||
for batch_size in batch_sizes])
|
||||
|
||||
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||
print(f"Batch size: {batch_size}, config: {config}")
|
||||
@ -323,7 +355,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tp-size", "-tp", type=int, default=2)
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8"],
|
||||
choices=["auto", "fp8_w8a8", "int8_w8a16"],
|
||||
default="auto")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
|
@ -175,7 +175,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
|
||||
|
benchmarks/kernels/benchmark_quant.py (new file, 103 lines)
@ -0,0 +1,103 @@
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(num_tokens: int,
|
||||
hidden_size: int,
|
||||
static_scale: bool,
|
||||
quant_dtype: torch.dtype,
|
||||
dtype: torch.dtype,
|
||||
seed: int = 0,
|
||||
do_profile: bool = False,
|
||||
num_warmup_iters: int = 5,
|
||||
num_iters: int = 100) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None
|
||||
|
||||
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
torch.cuda.synchronize()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(num_iters):
|
||||
if quant_dtype == torch.int8:
|
||||
ops.scaled_int8_quant(x, scale)
|
||||
else:
|
||||
ops.scaled_fp8_quant(x, scale)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
print("Warming up...")
|
||||
run_benchmark = run_cuda_benchmark
|
||||
run_benchmark(num_iters=num_warmup_iters, profile=False)
|
||||
|
||||
# Benchmark.
|
||||
if do_profile:
|
||||
latency = run_benchmark(num_iters=1, profile=True)
|
||||
else:
|
||||
latency = run_benchmark(num_iters=num_iters, profile=False)
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
def to_torch_dtype(dt):
|
||||
if dt == "int8":
|
||||
return torch.int8
|
||||
if dt == "fp8":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError(f"Unsupported dtype: {dt}")
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the quantization (fp8 or int8) kernel.")
|
||||
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||
parser.add_argument("--static-scale", action="store_true")
|
||||
parser.add_argument("--quant-dtype",
|
||||
type=str,
|
||||
choices=["fp8", "int8"],
|
||||
default="int8")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||
parser.add_argument("--num-iters",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations. "
|
||||
"If --profile is set, this number is ignored")
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
main(num_tokens=args.num_tokens,
|
||||
hidden_size=args.hidden_size,
|
||||
static_scale=args.static_scale,
|
||||
quant_dtype=to_torch_dtype(args.quant_dtype),
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
num_warmup_iters=args.num_warmup_iters,
|
||||
num_iters=args.num_iters)
|
@ -94,7 +94,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
|
||||
|
64
benchmarks/kernels/graph_machete_bench.py
Normal file
64
benchmarks/kernels/graph_machete_bench.py
Normal file
@ -0,0 +1,64 @@
|
||||
import math
|
||||
import pickle
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
parser.add_argument('filename', type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.filename, 'rb') as f:
|
||||
data: List[TMeasurement] = pickle.load(f)
|
||||
|
||||
results = defaultdict(lambda: list())
|
||||
for v in data:
|
||||
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
|
||||
if result is not None:
|
||||
KN = result.group(1)
|
||||
else:
|
||||
raise Exception("MKN not found")
|
||||
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
|
||||
if result is not None:
|
||||
M = result.group(1)
|
||||
else:
|
||||
raise Exception("MKN not found")
|
||||
|
||||
kernel = v.task_spec.description
|
||||
results[KN].append({
|
||||
"kernel": kernel,
|
||||
"batch_size": M,
|
||||
"median": v.median
|
||||
})
|
||||
|
||||
rows = int(math.ceil(len(results) / 2))
|
||||
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
|
||||
axs = axs.flatten()
|
||||
axs_idx = 0
|
||||
for shape, data in results.items():
|
||||
plt.sca(axs[axs_idx])
|
||||
df = pd.DataFrame(data)
|
||||
sns.lineplot(data=df,
|
||||
x="batch_size",
|
||||
y="median",
|
||||
hue="kernel",
|
||||
style="kernel",
|
||||
markers=True,
|
||||
dashes=False,
|
||||
palette="Dark2")
|
||||
plt.title(f"Shape: {shape}")
|
||||
plt.ylabel("time (median, s)")
|
||||
axs_idx += 1
|
||||
plt.tight_layout()
|
||||
plt.savefig("graph_machete_bench.pdf")
|
benchmarks/kernels/weight_shapes.py (new file, 43 lines)
@ -0,0 +1,43 @@
|
||||
# Weight Shapes are in the format
|
||||
# ([K, N], TP_SPLIT_DIM)
|
||||
# Example:
|
||||
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||
# - TP1 : K = 14336, N = 4096
|
||||
# - TP2 : K = 7168, N = 4096
|
||||
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||
# - TP1 : K = 4096, N = 6144
|
||||
# - TP4 : K = 4096, N = 1536
|
||||
|
||||
# TP1 shapes
|
||||
WEIGHT_SHAPES = {
|
||||
"mistralai/Mistral-7B-v0.1": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-2-7b-hf": [
|
||||
([4096, 12288], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 22016], 1),
|
||||
([11008, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-3-8b": [
|
||||
([4096, 6144], 1),
|
||||
([4096, 4096], 0),
|
||||
([4096, 28672], 1),
|
||||
([14336, 4096], 0),
|
||||
],
|
||||
"meta-llama/Llama-2-13b-hf": [
|
||||
([5120, 15360], 1),
|
||||
([5120, 5120], 0),
|
||||
([5120, 27648], 1),
|
||||
([13824, 5120], 0),
|
||||
],
|
||||
"meta-llama/Llama-2-70b-hf": [
|
||||
([8192, 10240], 1),
|
||||
([8192, 8192], 0),
|
||||
([8192, 57344], 1),
|
||||
([28672, 8192], 0),
|
||||
],
|
||||
}
|
@ -6,7 +6,7 @@ TOKENS=$2

 docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
     -v $PWD/data:/data \
-    ghcr.io/huggingface/text-generation-inference:1.4.0 \
+    ghcr.io/huggingface/text-generation-inference:2.2.0 \
     --model-id $MODEL \
     --sharded false \
     --max-input-length 1024 \
|
||||
|
@ -83,6 +83,8 @@ endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
|
||||
list(APPEND LIBS "numa")
|
||||
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/attention.cpp"
|
||||
"csrc/cpu/cache.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
@ -104,11 +107,11 @@ define_gpu_extension_target(
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
|
||||
add_custom_target(default)
|
||||
message(STATUS "Enabling C extension.")
|
||||
add_dependencies(default _C)
|
||||
|
@ -181,7 +181,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
||||
#
|
||||
# The torch cmake setup hardcodes the detected architecture flags in
|
||||
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
|
||||
# can't modified on a per-target basis, e.g. for the `punica` extension.
|
||||
# can't modified on a per-target basis.
|
||||
# So, all the `-gencode` flags need to be extracted and removed from
|
||||
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
|
||||
# Since it's not possible to use `target_compiler_options` for adding target
|
||||
|
@ -65,6 +65,9 @@ DEFAULT_CONDA_PATTERNS = {
|
||||
"optree",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
@ -77,6 +80,9 @@ DEFAULT_PIP_PATTERNS = {
|
||||
"onnx",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
"nvidia",
|
||||
"pynvml",
|
||||
}
|
||||
|
||||
|
||||
@ -263,8 +269,9 @@ def get_neuron_sdk_version(run_lambda):
|
||||
def get_vllm_version():
|
||||
try:
|
||||
import vllm
|
||||
return vllm.__version__
|
||||
except ImportError:
|
||||
return vllm.__version__ + "@" + vllm.__commit__
|
||||
except Exception:
|
||||
# old version of vllm does not have __commit__
|
||||
return 'N/A'
|
||||
|
||||
|
||||
|
@ -706,7 +706,7 @@ void paged_attention_v1_launcher(
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V1(112);
|
||||
break;
|
||||
case 120:
|
||||
LAUNCH_PAGED_ATTENTION_V1(120);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V1(128);
|
||||
break;
|
||||
@ -862,7 +865,7 @@ void paged_attention_v2_launcher(
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V2(112);
|
||||
break;
|
||||
case 120:
|
||||
LAUNCH_PAGED_ATTENTION_V2(120);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V2(128);
|
||||
break;
|
||||
|
@ -34,7 +34,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < N; ++ii) {
|
||||
qk_vec = fma(q[ii], k[ii], qk_vec);
|
||||
qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
|
||||
}
|
||||
|
||||
// Finalize the reduction across lanes.
|
||||
|
@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
|
||||
#else
|
||||
return __bfloat1622float2(val);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
|
||||
@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
|
||||
#else
|
||||
return __bfloat162bfloat162(val);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// Vector addition.
|
||||
@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#else
|
||||
return __hadd2(a, b);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
|
||||
@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#else
|
||||
return __hmul(a, b);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#else
|
||||
return __hmul2(a, b);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||
#else
|
||||
return __hfma2(a, b, c);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||
@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||
#else
|
||||
return __hfma2(bf162bf162(a), b, c);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
|
||||
|
@ -25,7 +25,8 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
const std::string& kv_cache_dtype,
|
||||
const double k_scale, const double v_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
|
@ -203,17 +203,18 @@ __global__ void reshape_and_cache_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size) {
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float k_scale, const float v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
k_cache[tgt_value_idx] = key[src_key_idx];
|
||||
v_cache[tgt_value_idx] = value[src_value_idx];
|
||||
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
key_cache[tgt_key_value_idx] = tgt_key;
|
||||
value_cache[tgt_key_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
value_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
@ -278,40 +288,45 @@ void reshape_and_cache(
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype) {
|
||||
// FIXME: only support auto datatype, does not support fp8
|
||||
if (kv_cache_dtype != "auto") {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = k_cache.size(1);
|
||||
int block_size = key_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = k_cache.stride(0);
|
||||
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
|
||||
int block_stride = key_cache.stride(0);
|
||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "reshape_and_cache_flash", [&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
|
||||
value_stride, num_heads, head_size, block_size);
|
||||
});
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
csrc/core/scalar_type.hpp (new file, 548 lines)
@ -0,0 +1,548 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
|
||||
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
|
||||
// can be used as a argument for custom operators, helping to simplify these
|
||||
// interfaces.
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : uint8_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
|
||||
int32_t bias, bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr){};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits, false, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(uint8_t exponent,
|
||||
uint8_t mantissa) {
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
|
||||
bool finite_values_only,
|
||||
NanRepr nan_repr) {
|
||||
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
TORCH_CHECK(nan_repr != NAN_IEEE_754,
|
||||
"use `float_IEEE754` constructor for floating point types that "
|
||||
"follow IEEE 754 conventions");
|
||||
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
|
||||
nan_repr);
|
||||
}
|
||||
|
||||
uint8_t const exponent; // size of the exponent field (0 for integer types)
|
||||
uint8_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
int32_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
using Id = int64_t;
|
||||
|
||||
private:
|
||||
// Field size in id
|
||||
template <typename T_>
|
||||
static constexpr size_t member_id_field_width() {
|
||||
using T = std::decay_t<T_>;
|
||||
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init, typename Member, typename... Rest>
|
||||
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
|
||||
Rest... rest) {
|
||||
auto new_val = f(val, member);
|
||||
if constexpr (sizeof...(rest) > 0) {
|
||||
return reduce_members_helper(f, new_val, rest...);
|
||||
} else {
|
||||
return new_val;
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
constexpr auto reduce_members(Fn f, Init init) const {
|
||||
// Should be in constructor order for `from_id`
|
||||
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
|
||||
finite_values_only, nan_repr);
|
||||
};
|
||||
|
||||
template <typename Fn, typename Init>
|
||||
static constexpr auto reduce_member_types(Fn f, Init init) {
|
||||
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
|
||||
return dummy_type.reduce_members(f, init);
|
||||
};
|
||||
|
||||
static constexpr auto id_size_bits() {
|
||||
return reduce_member_types(
|
||||
[](int acc, auto member) -> int {
|
||||
return acc + member_id_field_width<decltype(member)>();
|
||||
},
|
||||
0);
|
||||
}
|
||||
|
||||
public:
|
||||
// unique id for this scalar type that can be computed at compile time for
|
||||
// c++17 template specialization this is not needed once we migrate to
|
||||
// c++20 and can pass literal classes as template parameters
|
||||
constexpr Id id() const {
|
||||
static_assert(id_size_bits() <= sizeof(Id) * 8,
|
||||
"ScalarType id is too large to be stored");
|
||||
|
||||
auto or_and_advance = [](std::pair<Id, uint32_t> result,
|
||||
auto member) -> std::pair<Id, uint32_t> {
|
||||
auto [id, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<decltype(member)>();
|
||||
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
|
||||
<< bit_offset,
|
||||
bit_offset + bits};
|
||||
};
|
||||
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
|
||||
}
|
||||
|
||||
// create a ScalarType from an id, for c++17 template specialization,
|
||||
// this is not needed once we migrate to c++20 and can pass literal
|
||||
// classes as template parameters
|
||||
static constexpr ScalarType from_id(Id id) {
|
||||
auto extract_and_advance = [id](auto result, auto member) {
|
||||
using T = decltype(member);
|
||||
auto [tuple, bit_offset] = result;
|
||||
auto constexpr bits = member_id_field_width<T>();
|
||||
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
|
||||
((uint64_t(1) << bits) - 1));
|
||||
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
|
||||
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
|
||||
};
|
||||
|
||||
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
|
||||
std::pair<std::tuple<>, int>{});
|
||||
return std::apply([](auto... args) { return ScalarType(args...); },
|
||||
tuple_args);
|
||||
}
|
||||
|
||||
constexpr int64_t size_bits() const {
|
||||
return mantissa + exponent + is_signed();
|
||||
}
|
||||
constexpr bool is_signed() const { return signed_; }
|
||||
constexpr bool is_integer() const { return exponent == 0; }
|
||||
constexpr bool is_floating_point() const { return exponent > 0; }
|
||||
constexpr bool is_ieee_754() const {
|
||||
return is_floating_point() && finite_values_only == false &&
|
||||
nan_repr == NAN_IEEE_754;
|
||||
}
|
||||
constexpr bool has_nans() const {
|
||||
return is_floating_point() && nan_repr != NAN_NONE;
|
||||
}
|
||||
constexpr bool has_infs() const {
|
||||
return is_floating_point() && finite_values_only == false;
|
||||
}
|
||||
constexpr bool has_bias() const { return bias != 0; }
|
||||
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
TORCH_CHECK(mantissa <= 52 && exponent <= 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
max_mantissa -= 1;
|
||||
}
|
||||
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
TORCH_CHECK(exponent < 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
// adjust the exponent to match that of a double
|
||||
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
|
||||
// is the exponent bits), there is some precedent for non-standard biases,
|
||||
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
|
||||
// but to avoid premature over complication we are just assuming the
|
||||
// standard exponent bias until there is a need to support non-standard
|
||||
// biases
|
||||
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
|
||||
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
|
||||
|
||||
uint64_t max_exponent_double =
|
||||
max_exponent - exponent_bias + exponent_bias_double;
|
||||
|
||||
// shift the mantissa into the position for a double and
|
||||
// the exponent
|
||||
uint64_t double_raw =
|
||||
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
|
||||
|
||||
return *reinterpret_cast<double*>(&double_raw);
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_max() const {
|
||||
if (is_floating_point()) {
|
||||
return {_floating_point_max()};
|
||||
} else {
|
||||
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
|
||||
"Cannot represent max as a int64_t");
|
||||
return {(int64_t(1) << mantissa) - 1};
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::variant<int64_t, double> _raw_min() const {
|
||||
if (is_floating_point()) {
|
||||
TORCH_CHECK(is_signed(),
|
||||
"We currently assume all floating point types are signed");
|
||||
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
|
||||
|
||||
double max = _floating_point_max();
|
||||
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
|
||||
uint64_t min_raw = max_raw | sign_bit_double;
|
||||
return {*reinterpret_cast<double*>(&min_raw)};
|
||||
} else {
|
||||
TORCH_CHECK(!is_signed() || size_bits() <= 64,
|
||||
"Cannot represent min as a int64_t");
|
||||
if (is_signed()) {
|
||||
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
|
||||
// then perform an arithmetic shift right to set all the bits above
|
||||
// (size_bits() - 1) to 1
|
||||
return {INT64_MIN >> (64 - size_bits())};
|
||||
} else {
|
||||
return {int64_t(0)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> max() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_max());
|
||||
}
|
||||
|
||||
// Min representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
constexpr std::variant<int64_t, double> min() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_min());
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
* for floating point types (leading f) the scheme is:
|
||||
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
* flags:
|
||||
* - no-flags: means it follows IEEE 754 conventions
|
||||
* - f: means finite values only (no infinities)
|
||||
* - n: means nans are supported (non-standard encoding)
|
||||
* for integer types the scheme is:
|
||||
* `[u]int<size_bits>[b<bias>]`
|
||||
* - if bias is not present it means its zero
|
||||
*/
|
||||
if (is_floating_point()) {
|
||||
auto ret = "float" + std::to_string(size_bits()) + "_e" +
|
||||
std::to_string(exponent) + "m" + std::to_string(mantissa);
|
||||
if (!is_ieee_754()) {
|
||||
if (finite_values_only) {
|
||||
ret += "f";
|
||||
}
|
||||
if (nan_repr != NAN_NONE) {
|
||||
ret += "n";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
|
||||
if (has_bias()) {
|
||||
ret += "b" + std::to_string(bias);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr bool operator==(ScalarType const& other) const {
|
||||
return mantissa == other.mantissa && exponent == other.exponent &&
|
||||
bias == other.bias && signed_ == other.signed_ &&
|
||||
finite_values_only == other.finite_values_only &&
|
||||
nan_repr == other.nan_repr;
|
||||
}
|
||||
};
|
||||
|
||||
// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
|
||||
// torch::CustomClassHolder), we use multiple inheritance here since we cannot
|
||||
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
|
||||
// constructor at the same time (torch::CustomClassHolder does not have a
|
||||
// constexpr destructor)
|
||||
// See also:
|
||||
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
|
||||
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
public:
|
||||
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
|
||||
bool _signed)
|
||||
: ScalarType(exponent, mantissa, bias, _signed){};
|
||||
|
||||
ScalarTypeTorch(ScalarType type) : ScalarType(type){};
|
||||
|
||||
using Base = ScalarType;
|
||||
using Self = ScalarTypeTorch;
|
||||
using SelfPtr = c10::intrusive_ptr<Self>;
|
||||
|
||||
static void check_size_bits(int64_t size_bits, bool signed_) {
|
||||
TORCH_CHECK(
|
||||
size_bits <=
|
||||
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
|
||||
"size_bits bit width is too large to be represented");
|
||||
}
|
||||
|
||||
static void check_bias(int64_t bias) {
|
||||
using Bias = decltype(std::declval<Self>().bias);
|
||||
TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
|
||||
bias >= std::numeric_limits<Bias>::min(),
|
||||
"bias too large or small to be represented");
|
||||
}
|
||||
|
||||
static void check_exponent(int64_t exponent) {
|
||||
TORCH_CHECK(
|
||||
exponent <=
|
||||
std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
|
||||
"exponent bit width is too large to be represented");
|
||||
}
|
||||
|
||||
static void check_mantissa(int64_t mantissa) {
|
||||
TORCH_CHECK(
|
||||
mantissa <=
|
||||
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
|
||||
"mantissa bit width is too large to be represented");
|
||||
}
|
||||
|
||||
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||
check_size_bits(size_bits, true);
|
||||
check_bias(bias.value_or(0));
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::int_(size_bits, bias.value_or(0)));
|
||||
}
|
||||
|
||||
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||
check_size_bits(size_bits, true);
|
||||
check_bias(bias.value_or(0));
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::uint(size_bits, bias.value_or(0)));
|
||||
}
|
||||
|
||||
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
|
||||
check_mantissa(mantissa);
|
||||
check_exponent(exponent);
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::float_IEEE754(exponent, mantissa));
|
||||
}
|
||||
|
||||
static SelfPtr float_(int64_t exponent, int64_t mantissa,
|
||||
bool finite_values_only, int64_t nan_repr) {
|
||||
check_mantissa(mantissa);
|
||||
check_exponent(exponent);
|
||||
return c10::make_intrusive<Self>(ScalarType::float_(
|
||||
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
|
||||
}
|
||||
|
||||
// This needs to be implemented and throw a TypeError in order for
|
||||
// PyTorch's opcheck to work on ops that use ScalarTypes.
|
||||
int64_t len() const {
|
||||
throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
|
||||
"__len__ not implemented");
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Serialize a ScalarType into a tuple of pairs. Where each pair
|
||||
// is a (fieldname, value).
|
||||
// For simplicity, we are just going to convert to a ScalarTypeId.
|
||||
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
|
||||
return {{"ScalarType", id()}};
|
||||
}
|
||||
|
||||
// Deserialize a scalar type that has been serialized by obj_flatten,
|
||||
// ostensibly from a tuple of (member name, value) pairs, but in reality
|
||||
// just a ScalarTypeId.
|
||||
static SelfPtr obj_unflatten(
|
||||
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
|
||||
return c10::make_intrusive<Self>(
|
||||
from_id(std::get<1>(std::get<0>(flat_type))));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void bind_readonly_property(torch::class_<Self>& cls,
|
||||
std::string const& name, T Base::*field) {
|
||||
auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
|
||||
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
|
||||
return (self.get()->*field)();
|
||||
} else {
|
||||
return self.get()->*field;
|
||||
}
|
||||
};
|
||||
|
||||
auto getter_func = [field = std::move(field),
|
||||
getter_func_helper = std::move(getter_func_helper)](
|
||||
SelfPtr const& self) {
|
||||
auto val = getter_func_helper(self);
|
||||
// upconvert uint8_t, int32_t etc. to int64_t for python
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
return static_cast<int64_t>(val);
|
||||
} else {
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
cls.def_property(name, getter_func);
|
||||
}
|
||||
|
||||
template <typename MemberFunc, typename Cls>
|
||||
static void bind_function(torch::class_<Self>& cls, const std::string& name,
|
||||
MemberFunc Cls::*member) {
|
||||
cls.def(name, [member = std::move(member)](SelfPtr const& self) {
|
||||
return (self.get()->*member)();
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static void bind_function(torch::class_<Self>& cls, const std::string& name,
|
||||
Func func) {
|
||||
cls.def(name, func);
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static void bind_static_function(torch::class_<Self>& cls,
|
||||
const std::string& name, Func func) {
|
||||
cls.def_static(name, func);
|
||||
}
|
||||
|
||||
static void bind_class(torch::Library& lib) {
|
||||
auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
|
||||
.def(torch::init<int64_t, int64_t, int64_t, bool>());
|
||||
|
||||
// Bind Properties
|
||||
bind_readonly_property(cls, "mantissa", &Base::mantissa);
|
||||
bind_readonly_property(cls, "exponent", &Base::exponent);
|
||||
bind_readonly_property(cls, "bias", &Base::bias);
|
||||
bind_readonly_property(cls, "signed", &Base::is_signed);
|
||||
bind_readonly_property(cls, "size_bits", &Base::size_bits);
|
||||
|
||||
// Bind member functions
|
||||
bind_function(cls, "is_signed", &Base::is_signed);
|
||||
bind_function(cls, "is_integer", &Base::is_integer);
|
||||
bind_function(cls, "is_floating_point", &Base::is_floating_point);
|
||||
bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
|
||||
bind_function(cls, "has_nans", &Base::has_nans);
|
||||
bind_function(cls, "has_infs", &Base::has_infs);
|
||||
bind_function(cls, "has_bias", &Base::has_bias);
|
||||
|
||||
bind_function(cls, "max", [](SelfPtr const& self) {
|
||||
return std::visit([](auto arg) { return c10::IValue(arg); },
|
||||
self.get()->max());
|
||||
});
|
||||
bind_function(cls, "min", [](SelfPtr const& self) {
|
||||
return std::visit([](auto arg) { return c10::IValue(arg); },
|
||||
self.get()->min());
|
||||
});
|
||||
|
||||
bind_function(cls, "__len__", &ScalarTypeTorch::len);
|
||||
bind_function(cls, "__str__", &Base::str);
|
||||
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
|
||||
return *self == *other;
|
||||
});
|
||||
bind_function(cls, "__repr__", [](SelfPtr const& self) {
|
||||
return "ScalarType." + self.get()->str();
|
||||
});
|
||||
|
||||
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
|
||||
bind_static_function(cls, "__obj_unflatten__",
|
||||
&ScalarTypeTorch::obj_unflatten);
|
||||
|
||||
// Bind static functions (convenience constructors)
|
||||
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
|
||||
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
|
||||
bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
|
||||
bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarTypeId = int64_t;
|
||||
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE3M2f =
|
||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// Fixed width style names, generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
static inline constexpr auto kInt4 = kS4;
|
||||
static inline constexpr auto kUint4 = kU4;
|
||||
static inline constexpr auto kUint4b8 = kU4B8;
|
||||
static inline constexpr auto kInt8 = kS8;
|
||||
static inline constexpr auto kUint8 = kU8;
|
||||
static inline constexpr auto kUint8b128 = kU8B128;
|
||||
|
||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
static inline constexpr auto kHalf = kFE5M10;
|
||||
static inline constexpr auto kFloat16 = kHalf;
|
||||
static inline constexpr auto kBFloat16 = kFE8M7;
|
||||
|
||||
static inline constexpr auto kFloat16Id = kFloat16.id();
|
||||
}; // namespace vllm
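
The biased integer types above (kU4B8, kU8B128) follow the convention used by the converters later in this diff: the stored unsigned code equals the real value plus the bias. A minimal host-side sketch of that decode, independent of the ScalarType class itself (the helper names here are illustrative, not part of the diff):

#include <cassert>
#include <cstdint>

// Hypothetical helpers illustrating the bias convention of kU4B8 / kU8B128:
// the stored (unsigned) code equals the real value plus the bias.
static int decode_u4b8(uint8_t stored) { return static_cast<int>(stored) - 8; }
static int decode_u8b128(uint8_t stored) { return static_cast<int>(stored) - 128; }

int main() {
  // A u4b8 code of 0 represents -8, a code of 15 represents +7.
  assert(decode_u4b8(0) == -8 && decode_u4b8(15) == 7);
  // A u8b128 code of 0 represents -128, a code of 255 represents +127.
  assert(decode_u8b128(0) == -128 && decode_u8b128(255) == 127);
  return 0;
}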
|
csrc/core/torch_bindings.cpp (new file, 16 lines)
@@ -0,0 +1,16 @@
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "scalar_type.hpp"
|
||||
#include "registration.h"
|
||||
|
||||
// Note the CORE extension will be built for (almost) all hardware targets so
// new additions must account for this. (currently not built for TPU and Neuron)
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
|
||||
// ScalarType, a custom class for representing data types that supports
|
||||
// quantized types, declared here so it can be used when creating interfaces
|
||||
// for custom ops.
|
||||
vllm::ScalarTypeTorch::bind_class(lib);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
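
TORCH_LIBRARY_EXPAND and REGISTER_EXTENSION are vLLM helpers from registration.h, which is not shown in this excerpt. As a hedged sketch, a registration of this shape written against stock PyTorch APIs looks roughly like the following (the library and class names are placeholders, not vLLM's actual extension name):

#include <torch/custom_class.h>
#include <torch/library.h>

// Placeholder custom class; vLLM's real class is vllm::ScalarTypeTorch.
struct MyType : torch::CustomClassHolder {
  explicit MyType(int64_t v) : v_(v) {}
  int64_t value() const { return v_; }
  int64_t v_;
};

// Plain TORCH_LIBRARY registration; TORCH_LIBRARY_EXPAND presumably wraps this
// so that the library name can be driven by TORCH_EXTENSION_NAME at build time.
TORCH_LIBRARY(my_core_C, lib) {
  lib.class_<MyType>("MyType")
      .def(torch::init<int64_t>())
      .def("value", &MyType::value);
}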
|
@@ -1,9 +1,11 @@
|
||||
#include "cache.h"
|
||||
#include "ops.h"
|
||||
#include "registration.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
void init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
// CPU utils
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
csrc/cpu/utils.cpp (new file, 65 lines)
@@ -0,0 +1,65 @@
|
||||
#include <numa.h>
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
void init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
omp_cpu_ids.reserve(omp_cpu_mask->size);
|
||||
|
||||
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
|
||||
|
||||
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
|
||||
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
|
||||
int i = 0;
|
||||
while (group_mask) {
|
||||
if (group_mask & 1) {
|
||||
omp_cpu_ids.emplace_back(offset + i);
|
||||
}
|
||||
++i;
|
||||
group_mask >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Memory node binding
|
||||
if (numa_available() != -1) {
|
||||
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
|
||||
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
|
||||
bitmask* src_mask = numa_get_membind();
|
||||
|
||||
int pid = getpid();
|
||||
|
||||
// move all existing pages to the specified numa node.
|
||||
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
|
||||
int page_num = numa_migrate_pages(pid, src_mask, mask);
|
||||
if (page_num == -1) {
|
||||
TORCH_CHECK(false,
|
||||
"numa_migrate_pages failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
// restrict memory allocation node.
|
||||
numa_set_membind(mask);
|
||||
numa_set_strict(1);
|
||||
}
|
||||
|
||||
// OMP threads binding
|
||||
omp_set_num_threads((int)omp_cpu_ids.size());
|
||||
torch::set_num_threads((int)omp_cpu_ids.size());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
|
||||
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
|
||||
CPU_ZERO_S(size, mask);
|
||||
CPU_SET_S(omp_cpu_ids[i], size, mask);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), mask);
|
||||
CPU_FREE(mask);
|
||||
}
|
||||
|
||||
numa_free_nodemask(omp_cpu_mask);
|
||||
}
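
The group_mask loop above scans each word of the libnuma bitmask and records the index of every set bit as an allowed CPU id. The same bit-scanning idiom in isolation, with no libnuma dependency (the mask value is illustrative):

#include <cstdio>
#include <vector>

// Collect the indices of set bits in one mask word, mirroring the group_mask
// loop in init_cpu_threads_env (offset is the word's starting bit position).
static std::vector<int> set_bit_indices(unsigned long mask, int offset) {
  std::vector<int> ids;
  int i = 0;
  while (mask) {
    if (mask & 1) ids.push_back(offset + i);
    ++i;
    mask >>= 1;
  }
  return ids;
}

int main() {
  // CPUs 0, 1 and 5 allowed in the first 64-bit group.
  for (int id : set_bit_indices(0b100011UL, /*offset=*/0)) printf("%d ", id);
  printf("\n");  // prints: 0 1 5
  return 0;
}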
|
@@ -1,5 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
|
||||
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
|
||||
#define DEVICE_INLINE __forceinline__ __device__
|
||||
#define HOST_INLINE __forceinline__ __host__
|
||||
#else
|
||||
#define HOST_DEVICE_INLINE inline
|
||||
#define DEVICE_INLINE inline
|
||||
#define HOST_INLINE inline
|
||||
#endif
|
||||
|
||||
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
|
||||
|
||||
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
|
||||
|
csrc/cutlass_extensions/cute_utils.cuh (new file, 68 lines)
@@ -0,0 +1,68 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <torch/all.h>
|
||||
namespace cute {
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// layout utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Permute layout based on indices, example:
|
||||
// permute_layout<1, 0>(layout) will swap the two dimensions
|
||||
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
|
||||
template <size_t... I, typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
|
||||
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
|
||||
return cute::make_layout(cute::get<I>(l)...);
|
||||
}
|
||||
|
||||
// is the layout f(x) = x
|
||||
template <typename Layout>
|
||||
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
|
||||
if constexpr (std::is_same_v<Layout, void>)
|
||||
return true;
|
||||
else {
|
||||
constexpr auto coalesced_layout = coalesce(Layout{});
|
||||
if constexpr (rank(coalesced_layout) == 1 &&
|
||||
stride<0>(coalesced_layout) == 1) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Pointer utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <class PointerType>
|
||||
static constexpr auto get_logical_ptr(PointerType* ptr) {
|
||||
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
|
||||
return cute::subbyte_iterator<PointerType>(ptr);
|
||||
} else {
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
// Misc utils
|
||||
////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, typename Elements>
|
||||
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
|
||||
constexpr auto bits = sizeof_bits_v<T> * Elements{};
|
||||
if constexpr (bits % 128 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<128>{};
|
||||
} else if constexpr (bits % 64 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<64>{};
|
||||
} else if constexpr (bits % 32 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<32>{};
|
||||
} else if constexpr (bits % 16 == 0) {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<16>{};
|
||||
} else {
|
||||
return AutoVectorizingCopyWithAssumedAlignment<8>{};
|
||||
}
|
||||
}
|
||||
|
||||
}; // namespace cute
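
A small usage sketch of the layout helpers above, assuming the CUTLASS/CuTe headers are on the include path (the include path shown is an assumption of this sketch):

#include <cute/tensor.hpp>

#include "cutlass_extensions/cute_utils.cuh"  // include path assumed for this sketch

int main() {
  // A (2, 3) layout with strides (3, 1).
  auto l = cute::make_layout(cute::make_shape(cute::Int<2>{}, cute::Int<3>{}),
                             cute::make_stride(cute::Int<3>{}, cute::Int<1>{}));
  // Swap the two modes: the result has shape (3, 2) and strides (1, 3).
  auto lt = cute::permute_layout<1, 0>(l);
  (void)lt;

  // is_identity_layout<void>() is defined above to be true; a coalesced
  // rank-1, stride-1 layout also counts as the identity mapping f(x) = x.
  static_assert(cute::is_identity_layout<void>());
  return 0;
}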
|
csrc/cutlass_extensions/torch_utils.hpp (new file, 154 lines)
@@ -0,0 +1,154 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cute/layout.hpp"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
|
||||
namespace cute {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class T, class F, class G, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
|
||||
seq<I...>) {
|
||||
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
|
||||
}
|
||||
|
||||
template <class F, int... I>
|
||||
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
|
||||
return make_shape(f(I)...);
|
||||
}
|
||||
|
||||
}; // namespace detail
|
||||
|
||||
template <class T, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
|
||||
if constexpr (cute::is_tuple<T>::value) {
|
||||
return detail::tapply_with_idx(
|
||||
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
|
||||
tuple_seq<T>{});
|
||||
} else {
|
||||
return f(t);
|
||||
}
|
||||
|
||||
CUTE_GCC_UNREACHABLE;
|
||||
}
|
||||
|
||||
// calls: make_shape(f(0), f(1), ..., f(N-1))
|
||||
template <int N, class F>
|
||||
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
|
||||
return detail::make_shape_from_idx(f, make_seq<N>{});
|
||||
}
|
||||
|
||||
}; // namespace cute
|
||||
|
||||
// Make a layout from a tensor with rank `rank(Stride{})`, where the shape is
// the shape of the passed-in tensor and the strides are of type `Stride` and
// contain the strides of the passed-in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed-in tensor.
|
||||
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
|
||||
// strides are set to be 0 or 1.
|
||||
template <typename Stride>
|
||||
static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
|
||||
auto stride = cute::transform_with_idx(
|
||||
Stride{}, [&](auto const& stride_ele, auto const& idx) {
|
||||
using StrideEle = std::decay_t<decltype(stride_ele)>;
|
||||
|
||||
if (idx < tensor.dim()) {
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
|
||||
name, ".stride(", idx, ") to be ", StrideEle::value);
|
||||
return StrideEle{};
|
||||
} else {
|
||||
return tensor.stride(idx);
|
||||
}
|
||||
} else {
|
||||
// Extra strides are assumed to be 0 or 1
|
||||
if constexpr (cute::is_static_v<StrideEle>) {
|
||||
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
|
||||
}
|
||||
return StrideEle{};
|
||||
}
|
||||
});
|
||||
|
||||
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
|
||||
if (idx < tensor.dim())
|
||||
return tensor.size(idx);
|
||||
else
|
||||
return int64_t(1);
|
||||
});
|
||||
|
||||
return make_layout(shape, stride);
|
||||
}
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
c10::optional<torch::Tensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
if (tensor) {
|
||||
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
|
||||
} else {
|
||||
return std::optional<Layout>{};
|
||||
}
|
||||
}
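
A hedged usage sketch of make_cute_layout and maybe_make_cute_layout for a row-major matrix whose innermost stride is statically 1 (the include path, tensor shape and StrideA alias are illustrative):

#include <torch/all.h>

#include "cutlass_extensions/torch_utils.hpp"  // include path assumed for this sketch

void example() {
  // Row-major (M, K) tensor: the innermost stride is statically 1, the outer
  // stride stays dynamic (int64_t).
  using StrideA = cute::Stride<int64_t, cute::Int<1>>;

  torch::Tensor A = torch::empty({128, 64}, torch::kFloat16);
  auto layout_A = make_cute_layout<StrideA>(A, "A");  // checks A.stride(1) == 1
  (void)layout_A;

  // Optional tensors map to std::optional<Layout>.
  c10::optional<torch::Tensor> maybe_B = c10::nullopt;
  auto layout_B = maybe_make_cute_layout<StrideA>(maybe_B, "B");
  TORCH_CHECK(!layout_B.has_value());
}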
|
||||
|
||||
//
|
||||
// Torch Type to Cutlass Type (equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
template <typename T>
|
||||
struct equivalent_cutlass_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<c10::Half> {
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_cutlass_type<c10::BFloat16> {
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
|
||||
//
|
||||
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
|
||||
//
|
||||
|
||||
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
|
||||
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
|
||||
template <typename T>
|
||||
struct equivalent_scalar_type {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::half_t> {
|
||||
using type = c10::Half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct equivalent_scalar_type<cutlass::bfloat16_t> {
|
||||
using type = c10::BFloat16;
|
||||
};
|
||||
|
||||
// get equivalent c10::ScalarType tag from compile time type
|
||||
template <typename T>
|
||||
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
|
||||
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
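
The two trait families above are inverses of each other; a few compile-time checks make the mapping concrete (the include path is an assumption of this sketch):

#include <type_traits>

#include "cutlass_extensions/torch_utils.hpp"  // include path assumed for this sketch

// Torch -> CUTLASS
static_assert(std::is_same_v<equivalent_cutlass_type_t<c10::Half>, cutlass::half_t>);
static_assert(std::is_same_v<equivalent_cutlass_type_t<c10::BFloat16>, cutlass::bfloat16_t>);
// CUTLASS -> Torch, plus the runtime ScalarType tag
static_assert(std::is_same_v<equivalent_scalar_type_t<cutlass::half_t>, c10::Half>);
static_assert(equivalent_scalar_type_v<cutlass::half_t> == c10::ScalarType::Half);
static_assert(equivalent_scalar_type_v<cutlass::bfloat16_t> == c10::ScalarType::BFloat16);
// Types without a specialization map to themselves.
static_assert(std::is_same_v<equivalent_cutlass_type_t<float>, float>);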
|
csrc/cutlass_extensions/vllm_collective_builder.cuh (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
|
||||
namespace cutlass::gemm::collective {
|
||||
using namespace cute;
|
||||
|
||||
//
|
||||
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows
// for custom kernel tags, letting you build custom collectives without
// touching the cutlass library headers; using `CutlassKernelTag` falls back
// to the standard cutlass collective builder.
|
||||
//
|
||||
|
||||
// Use the default Cutlass collective builder, i.e. use an unmodified cutlass
// collective
|
||||
struct CutlassKernelTag {};
|
||||
|
||||
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
|
||||
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
|
||||
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
|
||||
class ClusterShape_MNK, class StageCountType,
|
||||
class KernelScheduleType, class Enable = void>
|
||||
struct VLLMCollectiveBuilder {
|
||||
static_assert(sizeof(ElementA) == 0,
|
||||
"Could not build a collective for given parameters.");
|
||||
};
|
||||
|
||||
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
|
||||
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
|
||||
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
|
||||
class StageCountType, class KernelScheduleType>
|
||||
struct VLLMCollectiveBuilder<
|
||||
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
|
||||
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType> {
|
||||
using CollectiveOp = typename CollectiveBuilder<
|
||||
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
|
||||
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
|
||||
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
|
||||
};
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
csrc/cutlass_extensions/vllm_custom_types.cuh (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/integer_subbyte.h"
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed = false>
|
||||
struct vllm_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
|
||||
using Base = integer_subbyte<Bits, Signed>;
|
||||
|
||||
using Storage = typename Base::Storage;
|
||||
using xint_t = typename Base::xint_t;
|
||||
|
||||
using Base::bits_mask_;
|
||||
using Base::sign_mask_;
|
||||
using Base::storage;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// No operation
|
||||
vllm_biased_integer_subbyte() = default;
|
||||
|
||||
/// Conversion from integer type
|
||||
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value)
|
||||
: Base(value) {}
|
||||
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value)
|
||||
: Base(value) {}
|
||||
};
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// "GPTQ" types, i.e. symmetric quantization
|
||||
using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8
|
||||
using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int Bits, int Bias, bool Signed>
|
||||
struct sizeof_bits<vllm_biased_integer_subbyte<Bits, Bias, Signed>> {
|
||||
static constexpr int value = Bits;
|
||||
};
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
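
The sizeof_bits specialization is what lets CUTLASS treat the biased types as true sub-byte/byte storage (e.g. for the subbyte_iterator used earlier). A quick compile-time check (include path assumed for this sketch):

#include "cutlass_extensions/vllm_custom_types.cuh"  // include path assumed for this sketch

// The biased sub-byte/byte types report their true bit widths to CUTLASS, so
// the packing machinery handles them like the corresponding plain integers.
static_assert(cutlass::sizeof_bits<cutlass::vllm_uint4b8_t>::value == 4);
static_assert(cutlass::sizeof_bits<cutlass::vllm_uint8b128_t>::value == 8);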
|
csrc/cutlass_extensions/vllm_cutlass_library_extension.py (new file, 49 lines)
@@ -0,0 +1,49 @@
|
||||
import enum
|
||||
from typing import Dict, Union
|
||||
|
||||
from cutlass_library import *
|
||||
|
||||
#
|
||||
# Extend cutlass library with custom types, and missing values
|
||||
#
|
||||
|
||||
|
||||
class VLLMDataType(enum.Enum):
|
||||
u4b8 = enum_auto()
|
||||
u8b128 = enum_auto()
|
||||
|
||||
|
||||
class MixedInputKernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedMixedInput = enum_auto()
|
||||
TmaWarpSpecializedPingpongMixedInput = enum_auto()
|
||||
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
|
||||
|
||||
|
||||
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
**DataTypeNames, # type: ignore
|
||||
**{
|
||||
VLLMDataType.u4b8: "u4b8",
|
||||
VLLMDataType.u8b128: "u8b128",
|
||||
}
|
||||
}
|
||||
|
||||
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
**DataTypeTag, # type: ignore
|
||||
**{
|
||||
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
|
||||
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
|
||||
}
|
||||
}
|
||||
|
||||
VLLMKernelScheduleTag: Dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
|
||||
}
|
||||
}
|
csrc/cutlass_extensions/vllm_numeric_conversion.cuh (new file, 795 lines)
@@ -0,0 +1,795 @@
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/numeric_conversion.h"
|
||||
#include "cutlass_extensions/vllm_custom_types.cuh"
|
||||
#include "cutlass_extensions/cute_utils.cuh"
|
||||
|
||||
// this file extends:
|
||||
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
|
||||
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
|
||||
// as well as adds interleaved numeric array converters for specific types.
|
||||
// (interleaved numeric array converters can be more efficient for subbyte
|
||||
// types)
|
||||
|
||||
namespace cutlass {
|
||||
|
||||
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout; interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
|
||||
class Enable = void>
|
||||
struct InterleavedNumericArrayConverter {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
CUTE_INVALID_CONTROL_PATH(
|
||||
"InterleavedNumericArrayConverter not implemented\n");
|
||||
return {};
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
template <typename IlvBlkLayout, typename T, typename S, int N,
|
||||
FloatRoundStyle Round>
|
||||
struct InterleavedNumericArrayConverter<
|
||||
IlvBlkLayout, T, S, N, Round,
|
||||
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
|
||||
using Converter = NumericArrayConverter<T, S, N, Round>;
|
||||
|
||||
using result_type = typename Converter::result_type;
|
||||
using source_type = typename Converter::source_type;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return Converter::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// TODO (LucasWilkinson): Implement
|
||||
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
|
||||
|
||||
// ....
|
||||
|
||||
template <typename RegConvert32bit, typename T, typename S, int N>
|
||||
struct ArrayConverterPacked32Bit {
|
||||
using result_type = Array<T, N>;
|
||||
using source_type = Array<S, N>;
|
||||
|
||||
using result_packed_8_t = Array<T, 8>;
|
||||
using result_packed_4_t = Array<T, 4>;
|
||||
using result_packed_2_t = Array<T, 2>;
|
||||
using src_packed_8_t = Array<S, 8>;
|
||||
using src_packed_4_t = Array<S, 4>;
|
||||
using src_packed_2_t = Array<S, 2>;
|
||||
|
||||
static_assert(N % 2 == 0, "N must be a multiple of 2");
|
||||
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
|
||||
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
|
||||
static constexpr auto src_elems_per_32bit_reg =
|
||||
32 / cutlass::sizeof_bits_v<S>;
|
||||
|
||||
// May not be valid: ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However, it won't be used
// anyway since we assert N % 2 == 0; it is just here for compliance with
// VectorizedConverter.
|
||||
using ScalarConverter = NumericConverter<T, S>;
|
||||
|
||||
template <typename PackedSrc>
|
||||
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
|
||||
if constexpr (sizeof(PackedSrc) == 1) {
|
||||
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
|
||||
} else if constexpr (sizeof(PackedSrc) == 2) {
|
||||
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
|
||||
} else {
|
||||
static_assert(sizeof(PackedSrc) == 4);
|
||||
return reinterpret_cast<const uint32_t&>(source);
|
||||
}
|
||||
}
|
||||
|
||||
// The core converter uses bit tricks to construct a known FP16 number, then
|
||||
// does a subtraction in FP16 for the final result.
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
|
||||
static_assert(PackedResultType::kElements == 2 ||
|
||||
PackedResultType::kElements == 4 ||
|
||||
PackedResultType::kElements == 8,
|
||||
"Invalid PackedResultType must be 2, 4 or 8.");
|
||||
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
|
||||
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
|
||||
|
||||
return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
ArrayConverterPacked32Bit<RegConvert32bit,
|
||||
typename result_type::Element,
|
||||
typename source_type::Element, N>;
|
||||
|
||||
if constexpr (src_elems_per_32bit_reg >= 8) {
|
||||
detail::VectorizedConverter::convert<
|
||||
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
|
||||
} else if constexpr (src_elems_per_32bit_reg >= 4) {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
} else {
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<vllm_uint4b8_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
|
||||
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
|
||||
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
|
||||
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
|
||||
// We use inline asm instead of __byte_perm intrinsic since we don't want
|
||||
// the documented (& 0x7) on the index. NVCC might be able to optimize it
|
||||
// out since the index is a constexpr, but we choose to be safe about it
|
||||
// here.
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for F16 -> I4 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a fp16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the FP16 to the correct value for the
|
||||
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
|
||||
// where x1 in the high nibble and x0 is the low nibble then using hfma
|
||||
// to subtract 1032 from that
|
||||
// The AND does the following:
|
||||
// 1) Clear the set bits for the int4 we will ignore.
|
||||
// We use lop3 so that we can use 1 instruction for AND and XOR.
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
static constexpr uint32_t and_mask = 0xFFF0FF0F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// We will issue 2 hfmas that do the following:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
|
||||
// = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
|
||||
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032}
|
||||
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
|
||||
|
||||
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
|
||||
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
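
The converter above relies on a classic fp16 trick: OR the 4-bit code into the mantissa of the bit pattern 0x6400 (fp16 1024.0), where one mantissa ulp is exactly 1.0, then subtract 1032 to undo both the 1024 offset and the +8 bias. A minimal host-side check of that arithmetic (no GPU needed; the fp16 decode here is a simplified one for normal values only):

#include <cassert>
#include <cstdint>
#include <cmath>

// Decode an IEEE fp16 bit pattern (normal numbers only) to float, just to
// check the magic-number arithmetic used by the converter on the host.
static float fp16_bits_to_float(uint16_t h) {
  int sign = (h >> 15) & 0x1;
  int exp = (h >> 10) & 0x1F;  // biased exponent
  int mant = h & 0x3FF;        // 10-bit mantissa
  float v = std::ldexp(1.0f + mant / 1024.0f, exp - 15);
  return sign ? -v : v;
}

int main() {
  for (int stored = 0; stored < 16; ++stored) {
    // 0x6400 is fp16 1024.0; OR-ing a nibble into the low mantissa bits
    // yields 1024 + nibble, since one ulp at 1024 is exactly 1.0.
    float fused = fp16_bits_to_float(uint16_t(0x6400 | stored));
    assert(fused == 1024.0f + stored);
    // u4b8 stores x + 8, so subtracting 1032 (= 1024 + 8) recovers x,
    // matching the hfma bias used by the converter above.
    assert(fused - 1032.0f == float(stored - 8));
  }
  return 0;
}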
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, vllm_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<vllm_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
|
||||
// - {72, 72}
|
||||
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::half_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t xor_mask = 0x64006400;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
|
||||
auto src_ = src >> (4 * (ii));
|
||||
r[ii + 0] = src_;
|
||||
r[ii + 1] = src_;
|
||||
|
||||
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
static constexpr uint32_t high_nib_mask = 0x00F000F0;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 1])
|
||||
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
|
||||
// For high nibble:
|
||||
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
|
||||
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
|
||||
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
|
||||
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
|
||||
}
|
||||
|
||||
{
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
|
||||
fp16x2_val = __hfma2(fp16x2_val,
|
||||
reinterpret_cast<const half2&>(high_nib_scale),
|
||||
reinterpret_cast<const half2&>(high_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::half_t, N> <= Array<vllm_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::half_t, N>;
|
||||
using source_type = Array<vllm_uint8b128_t, N>;
|
||||
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
// Hold output FP16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(start_byte_for_fp16),
|
||||
"r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
|
||||
static constexpr uint32_t bias_rep = 0x64806480;
|
||||
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hsub2(fp16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::float, N> <= Array<vllm_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
|
||||
using result_type = Array<float, N>;
|
||||
using source_type = Array<vllm_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
PackedResultType r;
|
||||
|
||||
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
|
||||
// u8x4 source and stores the result in r (without introducing extra
|
||||
// cvt.u32.u8 instruction)
|
||||
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
|
||||
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
|
||||
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
|
||||
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
|
||||
// Subtract the magic number 0x4B000000 from tmp in floating-point
|
||||
// arithmetic to obtain final result
|
||||
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
|
||||
}
|
||||
|
||||
return r;
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
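
The same idea in fp32: 0x4B000000 is 8388608.0f (2^23), where one mantissa ulp is exactly 1.0, so placing a byte in the low bits gives 8388608 + byte; subtracting 8388608 + 128 also folds in the u8b128 bias. A host-side check of the arithmetic:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // Mirror of the __byte_perm + subtraction trick above, done on the host.
  for (int stored = 0; stored < 256; ++stored) {
    uint32_t bits = 0x4B000000u | uint32_t(stored);
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    // Subtracting 8388608 removes the magic offset and the extra 128 undoes
    // the u8b128 bias, yielding the signed value directly.
    f -= (8388608.f + 128.f);
    assert(f == float(stored - 128));
  }
  return 0;
}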
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<vllm_uint4b8_t, N>;
|
||||
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
|
||||
// Hold output BF16s in reg. We need 1 reg for every 2 elements
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
uint32_t src_reg_shifted = src_reg >> 4;
|
||||
|
||||
// Below constructs the following temporary:
|
||||
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
|
||||
static_assert(RegArray::kElements <= 4,
|
||||
"Too many inputs for uint4b8_t -> BF16 vector converter");
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
|
||||
// we are trying to construct x and a BF16 value
|
||||
// The below XOR does the following:
|
||||
// 1) Sets the exponent bits of the BF16 to the correct value for the
|
||||
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
|
||||
// and subtracting 136 to get {x1, x0}
|
||||
static constexpr uint32_t xor_mask = 0x43004300;
|
||||
static constexpr uint32_t and_mask = 0x000F000F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
|
||||
// For each operand, computes:
|
||||
// r[i] = (r[i] & and_mask) ^ xor_mask
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
|
||||
// We will issue 2 bfmas that do the following:
|
||||
// high BF16:
|
||||
// hi_bf16 - 136, lo_bf16 - 136
|
||||
|
||||
// This is the BF16 {136, 136} represented as an integer.
|
||||
static constexpr uint32_t bias_rep = 0x43084308;
|
||||
const __nv_bfloat162& bias =
|
||||
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
bf16x2_val = __hsub2(bf16x2_val, bias);
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
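
The bf16 variant works the same way with the pattern 0x4300 (bf16 128.0): with only 7 mantissa bits, one ulp at 128 is 1.0 and only a single nibble fits, and subtracting 136 (= 128 + 8) undoes both the offset and the bias. A host-side check, using the fact that a bf16 bit pattern is simply the upper 16 bits of a float32:

#include <cassert>
#include <cstdint>
#include <cstring>

// A bfloat16 bit pattern is just the upper 16 bits of a float32.
static float bf16_bits_to_float(uint16_t b) {
  uint32_t bits = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  for (int stored = 0; stored < 16; ++stored) {
    // 0x4300 is bf16 128.0; OR-ing a nibble into the low mantissa bits
    // gives 128 + nibble, since one ulp at 128 is exactly 1.0.
    float fused = bf16_bits_to_float(uint16_t(0x4300 | stored));
    assert(fused == 128.0f + stored);
    // Subtracting 136 (the 0x4308 bias above) undoes both the 128 offset and
    // the +8 bias of u4b8.
    assert(fused - 136.0f == float(stored - 8));
  }
  return 0;
}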
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, vllm_uint4b8_t, N,
|
||||
Round, void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<vllm_uint4b8_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii + 0])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
|
||||
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
|
||||
// for IlvdLayout: (2, 4):(4, 1)
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
|
||||
cutlass::bfloat16_t, uint4_t, N, Round,
|
||||
void> {
|
||||
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
|
||||
static_assert(N % size(IlvdLayout{}) == 0);
|
||||
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<uint4_t, N>;
|
||||
|
||||
private:
|
||||
struct RegConvert {
|
||||
template <typename PackedResultType>
|
||||
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
|
||||
using RegArray =
|
||||
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
|
||||
sizeof(PackedResultType)>;
|
||||
RegArray r;
|
||||
|
||||
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
|
||||
static constexpr uint32_t or_mask = 0x43004300;
|
||||
|
||||
// Unlike float16 where the mantissa is large enough to contain 2
|
||||
// nibbles, bfloat16 can only fit one, so we can only convert one
|
||||
// nibble at a time
|
||||
for (int ii = 0; ii < RegArray::kElements; ++ii) {
|
||||
r[ii] = src >> (4 * ii);
|
||||
|
||||
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t low_nib_mask = 0x000F000F;
|
||||
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
|
||||
|
||||
// For low nibble:
|
||||
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
|
||||
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
|
||||
|
||||
{
|
||||
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
|
||||
fp16x2_val =
|
||||
__hsub2(fp16x2_val,
|
||||
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
|
||||
}
|
||||
}
|
||||
|
||||
return reinterpret_cast<PackedResultType&>(r);
|
||||
};
|
||||
};
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
|
||||
typename source_type::Element,
|
||||
N>::convert(source);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint8b128_t, N>
|
||||
template <FloatRoundStyle Round, int N>
|
||||
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
|
||||
using result_type = Array<cutlass::bfloat16_t, N>;
|
||||
using source_type = Array<vllm_uint8b128_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
private:
|
||||
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
|
||||
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
|
||||
using src_packed_4_t = Array<vllm_uint8b128_t, 4>;
|
||||
using src_packed_2_t = Array<vllm_uint8b128_t, 2>;
|
||||
|
||||
// Not Valid, not supported, only here to satisfy the interface and to avoid
|
||||
// a compile error. ScalarConverter will not actually work until
|
||||
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
|
||||
// implemented
|
||||
using ScalarConverter =
|
||||
NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round>;
|
||||
|
||||
template <typename PackedResultType, typename PackedSrcType>
|
||||
CUTLASS_DEVICE static PackedResultType packed_convert(
|
||||
PackedSrcType const& source) {
|
||||
static_assert(
|
||||
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
|
||||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
|
||||
platform::is_same<PackedResultType, result_packed_4_t>::value),
|
||||
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
|
||||
"convert dispatch.");
|
||||
|
||||
NumericArrayConverter<float, vllm_uint8b128_t, PackedResultType::kElements,
|
||||
Round>
|
||||
convert_uint8_to_f32;
|
||||
Array<float, PackedResultType::kElements> tmp =
|
||||
convert_uint8_to_f32(source);
|
||||
NumericArrayConverter<cutlass::bfloat16_t, float,
|
||||
PackedResultType::kElements, Round>
|
||||
convert_f32_to_bf16_;
|
||||
return convert_f32_to_bf16_(tmp);
|
||||
}
|
||||
|
||||
friend class detail::VectorizedConverter;
|
||||
|
||||
public:
|
||||
CUTLASS_DEVICE
|
||||
static result_type convert(source_type const& source) {
|
||||
result_type result;
|
||||
using ConverterType =
|
||||
NumericArrayConverter<typename result_type::Element,
|
||||
typename source_type::Element, N, Round>;
|
||||
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
|
||||
src_packed_4_t, result_packed_2_t,
|
||||
src_packed_2_t>(result, source);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
result_type operator()(source_type const& s) const { return convert(s); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
@@ -3,13 +3,16 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.cuh"
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hipcub/util_type.hpp>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
|
||||
using __nv_bfloat16 = __hip_bfloat16;
|
||||
using __nv_bfloat162 = __hip_bfloat162;
|
||||
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
|
||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
|
||||
variance += temp.sum_squares();
|
||||
residual_v[id] = temp;
|
||||
}
|
||||
/* Keep the following if-else block in sync with the
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else
|
||||
variance = blockReduceSum<float, 256>(variance);
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
|
||||
variance += x * x;
|
||||
residual[blockIdx.x * hidden_size + idx] = z;
|
||||
}
|
||||
/* Keep the following if-else block in sync with the
|
||||
calculation of max_block_size in fused_add_rms_norm */
|
||||
if (num_tokens < 256) {
|
||||
variance = blockReduceSum<float, 1024>(variance);
|
||||
} else
|
||||
variance = blockReduceSum<float, 256>(variance);
|
||||
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStore;
|
||||
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
|
csrc/mamba/causal_conv1d/causal_conv1d.cu  (new file, 700 lines)
@@ -0,0 +1,700 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#include "static_switch.h"
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
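// Illustrative aside (not part of the diff): the dispatch macro above expands a
// lambda once per supported dtype, with `input_t`/`weight_t` bound inside the
// lambda body. A hypothetical use, mirroring the real call sites further down,
// looks like:
//
//   DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "my_op", [&] {
//     causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
//   });
//
// i.e. the __VA_ARGS__ lambda is compiled three times, once per type pair.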
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template <typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
void* bias_ptr,
|
||||
bool silu_activation) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias_ptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.x_batch_stride = x.stride(0);
|
||||
params.x_c_stride = x.stride(1);
|
||||
params.x_l_stride = x.stride(-1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(0);
|
||||
params.out_c_stride = out.stride(1);
|
||||
params.out_l_stride = out.stride(-1);
|
||||
}
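// Illustrative note (not in the original): for a contiguous (batch, dim, seqlen)
// tensor the strides recorded above are (dim * seqlen, seqlen, 1); for the
// channel-last layout accepted by causal_conv1d_fwd (x.stride(1) == 1) they are
// (dim * seqlen, 1, dim). All strides are element counts, as the comment in
// set_conv_params_fwd states.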
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &seq_idx_,
|
||||
const c10::optional<at::Tensor> &initial_states_,
|
||||
const c10::optional<at::Tensor> &final_states_out_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
||||
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
||||
|
||||
if (is_channel_last) {
|
||||
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
||||
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
||||
}
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
||||
auto seq_idx = seq_idx_.value();
|
||||
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(seq_idx.is_cuda());
|
||||
TORCH_CHECK(seq_idx.is_contiguous());
|
||||
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
|
||||
if (seq_idx_.has_value()) {
|
||||
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
||||
} else {
|
||||
params.seq_idx_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (initial_states_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
||||
auto initial_states = initial_states_.value();
|
||||
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(initial_states.is_cuda());
|
||||
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(initial_states.stride(1) == 1);
|
||||
params.initial_states_ptr = initial_states.data_ptr();
|
||||
params.initial_states_batch_stride = initial_states.stride(0);
|
||||
params.initial_states_c_stride = initial_states.stride(1);
|
||||
params.initial_states_l_stride = initial_states.stride(2);
|
||||
} else {
|
||||
params.initial_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
if (final_states_out_.has_value()) {
|
||||
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
||||
auto final_states = final_states_out_.value();
|
||||
TORCH_CHECK(final_states.scalar_type() == input_type);
|
||||
TORCH_CHECK(final_states.is_cuda());
|
||||
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
||||
TORCH_CHECK(final_states.stride(1) == 1);
|
||||
params.final_states_ptr = final_states.data_ptr();
|
||||
params.final_states_batch_stride = final_states.stride(0);
|
||||
params.final_states_c_stride = final_states.stride(1);
|
||||
params.final_states_l_stride = final_states.stride(2);
|
||||
} else {
|
||||
params.final_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
if (!is_channel_last) {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
} else {
|
||||
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
}
|
||||
});
|
||||
return out;
|
||||
}
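// Illustrative usage sketch (assumes this translation unit is linked into an
// extension that exposes causal_conv1d_fwd; shapes follow the checks above:
// x is (batch, dim, seqlen), weight is (dim, width)).
//
//   at::Tensor x = torch::randn({2, 64, 128},
//                               torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   at::Tensor w = torch::randn({64, 4}, x.options());
//   at::Tensor y = causal_conv1d_fwd(x, w,
//                                    /*bias_=*/c10::nullopt,
//                                    /*seq_idx_=*/c10::nullopt,
//                                    /*initial_states_=*/c10::nullopt,
//                                    /*final_states_out_=*/c10::nullopt,
//                                    /*silu_activation=*/true);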
|
||||
|
||||
|
||||
at::Tensor
|
||||
causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
bool silu_activation) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int width = weight.size(-1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim);
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, width);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = torch::empty_like(x);
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
|
||||
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
||||
silu_activation);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
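// Worked example (illustrative, not in the original): for half-precision input
// (kNBytes == 2) the traits above give kNElts == 8, so with kNThreads == 128 each
// kernel iteration processes kChunkSize = 128 * 8 = 1024 elements, vec_t is a
// 16-byte uint4, and kSmemExchangeSize = 128 * 2 * 8 = 2048 bytes. For float
// input (kNBytes == 4), kNElts == 4 and a chunk covers 512 elements.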
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t zeros[kNElts] = {0};
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
}
|
||||
}
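// Illustrative scalar reference (not part of the diff) of what the block-parallel
// kernel above computes, ignoring chunking and vectorization: each output element
// is a causal dot product over the last `width` inputs with zero padding on the
// left, optionally followed by SiLU. The function name is hypothetical.
static void causal_conv1d_fwd_reference(const float* x, const float* w, float bias,
                                        float* out, int seqlen, int width, bool silu) {
    for (int t = 0; t < seqlen; ++t) {
        float acc = bias;
        for (int k = 0; k < width; ++k) {
            const int src = t - (width - 1 - k);   // taps only reach back in time
            acc += w[k] * (src >= 0 ? x[src] : 0.f);
        }
        out[t] = silu ? acc / (1.f + expf(-acc)) : acc;
    }
}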
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifndef USE_ROCM
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
||||
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
||||
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
||||
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
||||
// threads). Each load is 16 x 32|64 elements in the L x C dimensions.
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static_assert(kNThreads % 32 == 0);
|
||||
static constexpr int kNWarps = kNThreads / 32;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kChunkSizeL = kChunkSizeL_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
||||
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
||||
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
||||
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
||||
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
||||
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
||||
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
||||
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
||||
// sizeof(typename BlockStoreT::TempStorage)});
|
||||
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
||||
};
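// Worked example (illustrative, not in the original): with half-precision input
// and kNThreads == 128, the comment above works out to kNElts == 8 and
// kNEltsPerRow == 64, so kNThreadsPerRow == 8, kNColsPerWarp == 4, kNWarps == 4
// and therefore kNColsPerLoad == 16 rows of L per cooperative load; with the
// kChunkSizeL == 64 chosen by the launcher below, kNLoads == 64 / 16 == 4.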
|
||||
|
||||
template<typename Ktraits, bool kHasSeqIdx>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
||||
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
||||
|
||||
const int batch_id = blockIdx.x;
|
||||
const int chunk_l_id = blockIdx.y;
|
||||
const int chunk_c_id = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int l_idx = tid / kNThreadsPerC;
|
||||
const int c_idx = tid % kNThreadsPerC;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
||||
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
||||
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
||||
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
// The last L-chunk will also have enough info to write to final states, since it also contains a few x values
|
||||
// from the previous L-chunk.
|
||||
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
// Load the elements from the previous chunk that are needed for convolution.
|
||||
if (l_idx < kWidth - 1) {
|
||||
input_t x_vals_load[kNElts] = {0};
|
||||
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
||||
} else if (initial_states != nullptr
|
||||
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (final_states != nullptr
|
||||
&& l_idx < kWidth - 1
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
||||
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
||||
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
||||
}
|
||||
|
||||
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
||||
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
||||
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
||||
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
||||
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
||||
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
||||
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
||||
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
||||
static_assert(kNThreadsPerRow <= 32);
|
||||
|
||||
const int row_idx = tid / kNThreadsPerRow;
|
||||
const int col_idx = tid % kNThreadsPerRow;
|
||||
|
||||
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
||||
}
|
||||
}
|
||||
float x_vals[kWidth - 1 + kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
||||
}
|
||||
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
||||
if constexpr (kHasSeqIdx) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
||||
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
||||
}
|
||||
}
|
||||
|
||||
float out_vals[kLPerThread];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
if constexpr (!kHasSeqIdx) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
||||
} else {
|
||||
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
||||
}
|
||||
}
|
||||
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
||||
input_t out_vals_store[kNElts];
|
||||
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
||||
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
||||
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
||||
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
||||
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
||||
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
||||
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
||||
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
||||
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
||||
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
||||
dim3 block(Ktraits::kNThreads);
|
||||
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
||||
// if (kSmemSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
// }
|
||||
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
///////
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
}
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if (channel_id < params.dim) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
|
||||
x_vals[kWidth - 1] = float(x[0]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
|
||||
}
|
||||
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
if (channel_id < params.dim) { out[0] = input_t(out_val); }
|
||||
}
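// Illustrative scalar reference (not part of the diff) of the single-token update
// above: shift the per-channel state left by one, append the new input, and take
// the dot product with the filter taps (plus bias and optional SiLU). The name is
// hypothetical.
static void causal_conv1d_update_reference(float* conv_state, const float* weight,
                                           float x, float bias, int width, bool silu,
                                           float* out) {
    for (int i = 0; i < width - 1; ++i) { conv_state[i] = conv_state[i + 1]; }
    conv_state[width - 1] = x;
    float acc = bias;
    for (int i = 0; i < width; ++i) { acc += weight[i] * conv_state[i]; }
    *out = silu ? acc / (1.f + expf(-acc)) : acc;
}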
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
csrc/mamba/causal_conv1d/causal_conv1d.h  (new file, 144 lines)
@@ -0,0 +1,144 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
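// Illustrative usage (not part of the original header): reduce a per-lane value
// across a full warp with the butterfly pattern above; after the call every lane
// holds the warp-wide sum. The helper name is hypothetical.
template<typename T>
__device__ inline T warp_allreduce_sum_sketch(T val) {
    SumOp<T> op;
    return Allreduce<32>::run(val, op);
}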
|
csrc/mamba/causal_conv1d/static_switch.h  (new file, 28 lines)
@@ -0,0 +1,28 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
static constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
static constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
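// Illustrative note (not in the original): because each branch re-expands
// __VA_ARGS__ with its own `static constexpr bool CONST_NAME`, the runtime flag
// becomes a compile-time template argument in the lambda body, e.g.
//
//   BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
//     using Ktraits = Causal_conv1d_fwd_kernel_traits<128, 4, kIsVecLoad,
//                                                     input_t, weight_t>;
//     // ... launch with Ktraits ...
//   });
//
// which compiles both the vectorized and the guarded load path.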
|
csrc/mamba/mamba_ssm/selective_scan.h  (new file, 276 lines)
@@ -0,0 +1,276 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, dstate, n_groups, n_chunks;
|
||||
int dim_ngroups_ratio;
|
||||
bool is_variable_B;
|
||||
bool is_variable_C;
|
||||
|
||||
bool delta_softplus;
|
||||
|
||||
index_t A_d_stride;
|
||||
index_t A_dstate_stride;
|
||||
index_t B_batch_stride;
|
||||
index_t B_d_stride;
|
||||
index_t B_dstate_stride;
|
||||
index_t B_group_stride;
|
||||
index_t C_batch_stride;
|
||||
index_t C_d_stride;
|
||||
index_t C_dstate_stride;
|
||||
index_t C_group_stride;
|
||||
index_t u_batch_stride;
|
||||
index_t u_d_stride;
|
||||
index_t delta_batch_stride;
|
||||
index_t delta_d_stride;
|
||||
index_t z_batch_stride;
|
||||
index_t z_d_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_d_stride;
|
||||
index_t out_z_batch_stride;
|
||||
index_t out_z_d_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ A_ptr;
|
||||
void *__restrict__ B_ptr;
|
||||
void *__restrict__ C_ptr;
|
||||
void *__restrict__ D_ptr;
|
||||
void *__restrict__ u_ptr;
|
||||
void *__restrict__ delta_ptr;
|
||||
void *__restrict__ delta_bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ z_ptr;
|
||||
void *__restrict__ out_z_ptr;
|
||||
void *__restrict__ index_ptr;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#define MAX_DSTATE 256
|
||||
|
||||
|
||||
inline __device__ float2 operator+(const float2 & a, const float2 & b){
|
||||
return {a.x + b.x, a.y + b.y};
|
||||
}
|
||||
|
||||
inline __device__ float3 operator+(const float3 &a, const float3 &b) {
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z};
|
||||
}
|
||||
|
||||
inline __device__ float4 operator+(const float4 & a, const float4 & b){
|
||||
return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w};
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename scalar_t, int N>
|
||||
struct Converter{
|
||||
static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; ++i) { dst[i] = src[i]; }
|
||||
}
|
||||
};
|
||||
|
||||
template<int N>
|
||||
struct Converter<at::Half, N>{
|
||||
static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const half2 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
template<int N>
|
||||
struct Converter<at::BFloat16, N>{
|
||||
static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) {
|
||||
static_assert(N % 2 == 0);
|
||||
auto &src2 = reinterpret_cast<const nv_bfloat162 (&)[N / 2]>(src);
|
||||
auto &dst2 = reinterpret_cast<float2 (&)[N / 2]>(dst);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); }
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
template<typename scalar_t> struct SSMScanOp;
|
||||
|
||||
template<>
|
||||
struct SSMScanOp<float> {
|
||||
__device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const {
|
||||
return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y);
|
||||
}
|
||||
};
|
||||
|
||||
// A stateful callback functor that maintains a running prefix to be applied
|
||||
// during consecutive scan operations.
|
||||
template <typename scalar_t> struct SSMScanPrefixCallbackOp {
|
||||
using scan_t = std::conditional_t<std::is_same_v<scalar_t, float>, float2, float4>;
|
||||
scan_t running_prefix;
|
||||
// Constructor
|
||||
__device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {}
|
||||
// Callback operator to be entered by the first warp of threads in the block.
|
||||
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
|
||||
__device__ scan_t operator()(scan_t block_aggregate) {
|
||||
scan_t old_prefix = running_prefix;
|
||||
running_prefix = SSMScanOp<scalar_t>()(running_prefix, block_aggregate);
|
||||
return old_prefix;
|
||||
}
|
||||
};
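// Illustrative note (not part of the diff): each scan element (a, b) stands for
// the affine update h -> a * h + b. Composing step 0 and then step 1 gives
//   h -> a1 * (a0 * h + b0) + b1 = (a1 * a0) * h + (a1 * b0 + b1),
// which is exactly what SSMScanOp<float> returns:
//   make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y).
// The associativity of this composition is what lets cub::BlockScan evaluate the
// recurrence in parallel, with SSMScanPrefixCallbackOp carrying the running
// prefix across chunks.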
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_input(typename Ktraits::input_t *u,
|
||||
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadT::TempStorage &smem_load,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
|
||||
reinterpret_cast<vec_t*>(u),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(u_vals)
|
||||
#ifdef USE_ROCM
|
||||
, Ktraits::kNThreads * Ktraits::kNLoads
|
||||
#endif
|
||||
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_index(int *u,
|
||||
int (&u_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
|
||||
int seqlen) {
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
|
||||
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
|
||||
reinterpret_cast<uint4*>(u),
|
||||
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
|
||||
);
|
||||
} else {
|
||||
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
|
||||
typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight,
|
||||
int seqlen) {
|
||||
constexpr int kNItems = Ktraits::kNItems;
|
||||
typename Ktraits::input_t B_vals_load[kNItems];
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
|
||||
reinterpret_cast<vec_t*>(Bvar),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(B_vals_load)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f);
|
||||
}
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; }
|
||||
Converter<typename Ktraits::input_t, kNItems>::to_float(B_vals_load, B_vals);
|
||||
}
|
||||
|
||||
template<typename Ktraits>
|
||||
inline __device__ void store_output(typename Ktraits::input_t *out,
|
||||
const float (&out_vals)[Ktraits::kNItems],
|
||||
typename Ktraits::BlockStoreT::TempStorage &smem_store,
|
||||
int seqlen) {
|
||||
typename Ktraits::input_t write_vals[Ktraits::kNItems];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
|
||||
if constexpr (Ktraits::kIsEvenLen) {
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
|
||||
reinterpret_cast<vec_t*>(out),
|
||||
reinterpret_cast<vec_t(&)[Ktraits::kNLoads]>(write_vals)
|
||||
);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen);
|
||||
}
|
||||
}
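// Illustrative note (not in the original): when Ktraits::kIsEvenLen holds, the
// helpers above reinterpret the element arrays as vec_t / uint4 packets and use
// the vectorized cub BlockLoad/BlockStore specializations; otherwise they fall
// back to the element-wise overloads that take `seqlen` (plus a fill value) so
// the tail of an uneven sequence is handled safely.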
|
csrc/mamba/mamba_ssm/selective_scan_fwd.cu  (new file, 593 lines)
@@ -0,0 +1,593 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "selective_scan.h"
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
|
||||
bool kIsVariableB_, bool kIsVariableC_,
|
||||
bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
|
||||
struct Selective_Scan_fwd_kernel_traits {
|
||||
static_assert(kNItems_ % 4 == 0);
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
// Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy.
|
||||
static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3;
|
||||
static constexpr int kNItems = kNItems_;
|
||||
static constexpr int kNRows = kNRows_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
|
||||
static_assert(kNItems % kNElts == 0);
|
||||
static constexpr int kNLoads = kNItems / kNElts;
|
||||
static constexpr bool kIsEvenLen = kIsEvenLen_;
|
||||
static constexpr bool kIsVariableB = kIsVariableB_;
|
||||
static constexpr bool kIsVariableC = kIsVariableC_;
|
||||
static constexpr bool kHasZ = kHasZ_;
|
||||
static constexpr bool kUseIndex = kUseIndex_;
|
||||
|
||||
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
|
||||
static constexpr int kNLoadsIndex = kNItems / 4;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using scan_t = float2;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
|
||||
!(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads ,
|
||||
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, kNLoads,
|
||||
!kDirectIO ? cub::BLOCK_STORE_WARP_TRANSPOSE : cub::BLOCK_STORE_DIRECT>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING_MEMOIZE>;
|
||||
// using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_RAKING>;
|
||||
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
|
||||
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
|
||||
sizeof(typename BlockLoadVecT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexT::TempStorage),
|
||||
sizeof(typename BlockLoadIndexVecT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
|
||||
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
|
||||
sizeof(typename BlockStoreT::TempStorage),
|
||||
sizeof(typename BlockStoreVecT::TempStorage)});
|
||||
static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage);
|
||||
};
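// Worked example (illustrative, not in the original): for half-precision input
// (kNBytes == 2) and, say, kNItems == 16, the traits above give
// kNElts == min(8, 16) == 8, kNLoads == 2 and kNLoadsIndex == 4, so kDirectIO is
// false and the warp-transposing BlockLoad/BlockStore variants are selected;
// with float input and kNItems == 4, kNElts == 4, kNLoads == 1 and (for even
// sequence lengths) kDirectIO becomes true.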
|
||||
|
||||
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks)
void selective_scan_fwd_kernel(SSMParamsBase params) {
    constexpr bool kIsVariableB = Ktraits::kIsVariableB;
    constexpr bool kIsVariableC = Ktraits::kIsVariableC;
    constexpr bool kHasZ = Ktraits::kHasZ;
    constexpr bool kUseIndex = Ktraits::kUseIndex;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNItems = Ktraits::kNItems;
    constexpr int kNRows = Ktraits::kNRows;
    constexpr bool kDirectIO = Ktraits::kDirectIO;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    using scan_t = typename Ktraits::scan_t;

    // Shared memory.
    extern __shared__ char smem_[];
    // cast to lvalue reference of expected type
    // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t);
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_ + 2 * MAX_DSTATE * sizeof(weight_t));
    // auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
    auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
    auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
    // weight_t *smem_a = reinterpret_cast<weight_t *>(smem_ + smem_loadstorescan_size);
    // weight_t *smem_bc = reinterpret_cast<weight_t *>(smem_a + MAX_DSTATE);
    scan_t *smem_running_prefix = reinterpret_cast<scan_t *>(smem_ + Ktraits::kSmemSize);

    const int batch_id = blockIdx.x;
    const int dim_id = blockIdx.y;
    const int group_id = dim_id / (params.dim_ngroups_ratio);
    input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
        + dim_id * kNRows * params.u_d_stride;
    input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
        + dim_id * kNRows * params.delta_d_stride;
    weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
    weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
    input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
    weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
    input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
    scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
    int *index = !kUseIndex ? nullptr : reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;

    float D_val[kNRows] = {0};
    if (params.D_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            D_val[r] = reinterpret_cast<float *>(params.D_ptr)[dim_id * kNRows + r];
        }
    }
    float delta_bias[kNRows] = {0};
    if (params.delta_bias_ptr != nullptr) {
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            delta_bias[r] = reinterpret_cast<float *>(params.delta_bias_ptr)[dim_id * kNRows + r];
        }
    }

    // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) {
    //     smem_a[state_idx] = A[state_idx * params.A_dstate_stride];
    //     smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride];
    // }

    constexpr int kChunkSize = kNThreads * kNItems;
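    // The sequence is processed in chunks of kChunkSize (= kNThreads * kNItems) timesteps.
    // State is carried across chunk iterations: for chunk 0 the initial state is read from the x
    // checkpoint tensor, and for later chunks lane 0 of each warp reads the running prefix left in
    // smem_running_prefix by the previous chunk; thread 0 also checkpoints each chunk's final
    // state back into x.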
    for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
        input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
        int index_vals_load[kNRows][kNItems];

        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (!kDirectIO) { __syncthreads(); }
            load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
            if constexpr (kUseIndex) {
                load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
            }
        }
        if constexpr (kUseIndex) {
            index += kChunkSize;
        }
        u += kChunkSize;
        delta += kChunkSize;

        float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems];
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            #pragma unroll
            for (int i = 0; i < kNItems; ++i) {
                float u_val = float(u_vals[r][i]);
                delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r];
                if (params.delta_softplus) {
                    delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i];
                }
                delta_u_vals[r][i] = delta_vals[r][i] * u_val;
                out_vals[r][i] = D_val[r] * u_val;
            }
        }

        __syncthreads();
        for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
            weight_t A_val[kNRows];
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride];
                // Multiply the real part of A with LOG2E so we can use exp2f instead of expf.
                constexpr float kLog2e = M_LOG2E;
                A_val[r] *= kLog2e;
            }
            // This variable holds B * C if both B and C are constant across seqlen. If only B varies
            // across seqlen, this holds C. If only C varies across seqlen, this holds B.
            // If both B and C vary, this is unused.
            weight_t BC_val[kNRows];
            weight_t B_vals[kNItems], C_vals[kNItems];
            if constexpr (kIsVariableB) {
                load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
                    smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
                if constexpr (!kIsVariableC) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                    }
                }
            }
            if constexpr (kIsVariableC) {
                auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
                load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
                    smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1));
                if constexpr (!kIsVariableB) {
                    #pragma unroll
                    for (int r = 0; r < kNRows; ++r) {
                        BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride];
                    }
                }
            }
            if constexpr (!kIsVariableB && !kIsVariableC) {
                #pragma unroll
                for (int r = 0; r < kNRows; ++r) {
                    BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride];
                }
            }

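            // Each scan element packs (A_bar, B_bar * u): A_bar = exp2f(delta * A * log2(e)), i.e.
            // expf(delta * A), and the second component is delta * u (scaled by B here when B varies
            // over seqlen, or folded into BC_val after the scan otherwise). The inclusive block scan
            // below composes the first-order recurrence
            //   h_t = A_bar_t * h_{t-1} + B_bar_t * u_t,
            // so thread_data[i].y ends up holding the SSM state at timestep i, which is then
            // contracted with C.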
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                if (r > 0) { __syncthreads(); }  // Scan could be using the same smem
                scan_t thread_data[kNItems];
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
                        !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);

                    // Reset A bar for cumulative sequences (Real)
                    if constexpr (kUseIndex) {
                        if (index_vals_load[r][i] == 0) {
                            thread_data[i].x = 0.f;
                        }
                    }

                    if constexpr (!Ktraits::kIsEvenLen) {  // So that the last state is correct
                        if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
                            thread_data[i] = make_float2(1.f, 0.f);
                        }
                    }
                }
                // Initialize running total
                scan_t running_prefix;
                // With WARP_SCANS, lane 0 of every warp (not just thread 0) needs to read the running prefix.
                running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : (threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
                // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
                SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
                typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
                    thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
                );
                // There's a syncthreads in the scan op, so we don't need to sync here.
                // Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
                if (threadIdx.x == 0) {
                    smem_running_prefix[state_idx] = prefix_op.running_prefix;
                    x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
                }
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    const weight_t C_val = !kIsVariableC
                        ? BC_val[r]
                        : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]);
                    out_vals[r][i] += thread_data[i].y * C_val;
                }
            }
        }

        input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
            + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
        __syncthreads();
        #pragma unroll
        for (int r = 0; r < kNRows; ++r) {
            if constexpr (!kDirectIO) {
                if (r > 0) { __syncthreads(); }
            }
            store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
        }

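        // Optional gating branch: multiply the scan output elementwise by SiLU(z) = z * sigmoid(z)
        // before writing the gated result to out_z.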
        if constexpr (kHasZ) {
            input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
                + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
            input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
                + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
            #pragma unroll
            for (int r = 0; r < kNRows; ++r) {
                input_t z_vals[kNItems];
                __syncthreads();
                load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
                #pragma unroll
                for (int i = 0; i < kNItems; ++i) {
                    float z_val = z_vals[i];
                    out_vals[r][i] *= z_val / (1 + expf(-z_val));
                }
                __syncthreads();
                store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
            }
        }

        Bvar += kChunkSize * 1;
        Cvar += kChunkSize * 1;
    }
}

template<int kNThreads, int kNItems, typename input_t, typename weight_t>
void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
    // Only kNRows == 1 is tested for now, which of course doesn't differ from previously when we
    // had each block processing 1 row.
    constexpr int kNRows = 1;
    // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
    constexpr bool kIsVariableB = true;
    constexpr bool kIsVariableC = true;
    constexpr bool kHasZ = true;
    BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
        BOOL_SWITCH(params.index_ptr != nullptr, kUseIndex, [&] {
            using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
            constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
            dim3 grid(params.batch, params.dim / kNRows);
            auto kernel = &selective_scan_fwd_kernel<Ktraits>;
            if (kSmemSize >= 48 * 1024) {
                C10_CUDA_CHECK(cudaFuncSetAttribute(
                    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            }
            kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    });
}

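// Note: requesting more than 48 KB of dynamic shared memory requires an explicit opt-in via
// cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize), which is why
// selective_scan_fwd_launch sets the attribute above before launching the kernel.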
template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream) {

#ifndef USE_ROCM
    if (params.seqlen <= 128) {
        selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 256) {
        selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
#else
    if (params.seqlen <= 256) {
        selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 512) {
        selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream);
    } else if (params.seqlen <= 1024) {
        selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream);
    } else {
        selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream);
    }
#endif
}

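// The largest configuration above (128 threads x 16 items) covers 2048 timesteps per chunk,
// matching the n_chunks = ceil(seqlen / 2048) chunking done by the selective_scan_fwd wrapper below.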
template void selective_scan_fwd_cuda<at::BFloat16, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<at::Half, float>(SSMParamsBase &params, cudaStream_t stream);
template void selective_scan_fwd_cuda<float, float>(SSMParamsBase &params, cudaStream_t stream);

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)           \
  if (ITYPE == at::ScalarType::Half) {                                           \
    using input_t = at::Half;                                                    \
    using weight_t = float;                                                      \
    __VA_ARGS__();                                                               \
  } else if (ITYPE == at::ScalarType::BFloat16) {                                \
    using input_t = at::BFloat16;                                                \
    using weight_t = float;                                                      \
    __VA_ARGS__();                                                               \
  } else if (ITYPE == at::ScalarType::Float) {                                   \
    using input_t = float;                                                       \
    using weight_t = float;                                                      \
    __VA_ARGS__();                                                               \
  } else {                                                                       \
    AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'");  \
  }

template<typename input_t, typename weight_t>
void selective_scan_fwd_cuda(SSMParamsBase &params, cudaStream_t stream);

void set_ssm_params_fwd(SSMParamsBase &params,
                        // sizes
                        const size_t batch,
                        const size_t dim,
                        const size_t seqlen,
                        const size_t dstate,
                        const size_t n_groups,
                        const size_t n_chunks,
                        const bool is_variable_B,
                        const bool is_variable_C,
                        // device pointers
                        const torch::Tensor u,
                        const torch::Tensor delta,
                        const torch::Tensor A,
                        const torch::Tensor B,
                        const torch::Tensor C,
                        const torch::Tensor out,
                        const torch::Tensor z,
                        const torch::Tensor out_z,
                        void* D_ptr,
                        void* delta_bias_ptr,
                        void* x_ptr,
                        bool has_z,
                        bool delta_softplus,
                        void* index_ptr) {

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.dstate = dstate;
    params.n_groups = n_groups;
    params.n_chunks = n_chunks;
    params.dim_ngroups_ratio = dim / n_groups;

    params.delta_softplus = delta_softplus;

    params.is_variable_B = is_variable_B;
    params.is_variable_C = is_variable_C;

    // Set the pointers and strides.
    params.u_ptr = u.data_ptr();
    params.delta_ptr = delta.data_ptr();
    params.A_ptr = A.data_ptr();
    params.B_ptr = B.data_ptr();
    params.C_ptr = C.data_ptr();
    params.D_ptr = D_ptr;
    params.delta_bias_ptr = delta_bias_ptr;
    params.out_ptr = out.data_ptr();
    params.x_ptr = x_ptr;
    params.z_ptr = has_z ? z.data_ptr() : nullptr;
    params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;

    params.index_ptr = index_ptr;

    // All strides are in elements, not bytes.
    params.A_d_stride = A.stride(0);
    params.A_dstate_stride = A.stride(1);
    if (!is_variable_B) {
        params.B_d_stride = B.stride(0);
    } else {
        params.B_batch_stride = B.stride(0);
        params.B_group_stride = B.stride(1);
    }
    params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
    if (!is_variable_C) {
        params.C_d_stride = C.stride(0);
    } else {
        params.C_batch_stride = C.stride(0);
        params.C_group_stride = C.stride(1);
    }
    params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
    params.u_batch_stride = u.stride(0);
    params.u_d_stride = u.stride(1);
    params.delta_batch_stride = delta.stride(0);
    params.delta_d_stride = delta.stride(1);
    if (has_z) {
        params.z_batch_stride = z.stride(0);
        params.z_d_stride = z.stride(1);
        params.out_z_batch_stride = out_z.stride(0);
        params.out_z_d_stride = out_z.stride(1);
    }
    params.out_batch_stride = out.stride(0);
    params.out_d_stride = out.stride(1);
}

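// Shape conventions enforced by the checks below: u, delta, z and out are (batch, dim, seqlen);
// A is (dim, dstate); variable B and C are (batch, n_groups, dstate, seqlen); the optional x
// checkpoint tensor is (batch, dim, n_chunks, dstate * 2).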
std::vector<torch::Tensor>
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
                   const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
                   const c10::optional<torch::Tensor> &D_,
                   const c10::optional<torch::Tensor> &z_,
                   const c10::optional<torch::Tensor> &delta_bias_,
                   bool delta_softplus,
                   const c10::optional<torch::Tensor> &index_,
                   const c10::optional<torch::Tensor> &x) {
    auto input_type = u.scalar_type();
    auto weight_type = A.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float);

    const bool is_variable_B = B.dim() >= 3;
    const bool is_variable_C = C.dim() >= 3;

    TORCH_CHECK(delta.scalar_type() == input_type);
    TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type));
    TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type));

    TORCH_CHECK(u.is_cuda());
    TORCH_CHECK(delta.is_cuda());
    TORCH_CHECK(A.is_cuda());
    TORCH_CHECK(B.is_cuda());
    TORCH_CHECK(C.is_cuda());

    TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
    TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

    const auto sizes = u.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int dstate = A.size(1);
    const int n_groups = is_variable_B ? B.size(1) : 1;

    TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");

    CHECK_SHAPE(u, batch_size, dim, seqlen);
    CHECK_SHAPE(delta, batch_size, dim, seqlen);
    CHECK_SHAPE(A, dim, dstate);
    TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
    CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
    TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);

    TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
    CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
    TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);

    if (D_.has_value()) {
        auto D = D_.value();
        TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(D.is_cuda());
        TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
        CHECK_SHAPE(D, dim);
    }

    if (delta_bias_.has_value()) {
        auto delta_bias = delta_bias_.value();
        TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
        TORCH_CHECK(delta_bias.is_cuda());
        TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
        CHECK_SHAPE(delta_bias, dim);
    }
    if (index_.has_value()) {
        auto index = index_.value();
        TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
        TORCH_CHECK(index.is_cuda());
        CHECK_SHAPE(index, batch_size, seqlen);
    }

    at::Tensor z, out_z;
    const bool has_z = z_.has_value();
    TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
    z = z_.value();
    TORCH_CHECK(z.scalar_type() == input_type);
    TORCH_CHECK(z.is_cuda());
    TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
    CHECK_SHAPE(z, batch_size, dim, seqlen);
    out_z = torch::empty_like(z);

    const int n_chunks = (seqlen + 2048 - 1) / 2048;
    // const int n_chunks = (seqlen + 1024 - 1) / 1024;
    // at::Tensor out = torch::empty_like(u);
    // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
    at::Tensor out = torch::empty_like(delta);
    if (x.has_value()) {
        auto _x = x.value();
        TORCH_CHECK(_x.scalar_type() == weight_type);
        TORCH_CHECK(_x.is_cuda());
        TORCH_CHECK(_x.stride(-1) == 1);
        CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
    }

    SSMParamsBase params;
    set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
                       u, delta, A, B, C, out, z, out_z,
                       D_.has_value() ? D_.value().data_ptr() : nullptr,
                       delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
                       x.value().data_ptr(),
                       has_z,
                       delta_softplus,
                       index_.has_value() ? index_.value().data_ptr() : nullptr);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)u.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
        selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
    });
    std::vector<at::Tensor> result = {out, x.value()};
    if (has_z) { result.push_back(out_z); }
    return result;
}

csrc/mamba/mamba_ssm/static_switch.h (new file, 28 lines)
@@ -0,0 +1,28 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h
#pragma once

/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)  \
  [&] {                                     \
    if (COND) {                             \
      constexpr bool CONST_NAME = true;     \
      return __VA_ARGS__();                 \
    } else {                                \
      constexpr bool CONST_NAME = false;    \
      return __VA_ARGS__();                 \
    }                                       \
  }()

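// BOOL_SWITCH calls can be nested so that several runtime flags become compile-time template
// parameters in one place, as selective_scan_fwd_launch does in selective_scan_fwd.cu above.
// A minimal sketch (run_kernel and the flag expressions here are placeholders, not part of this header):
//
//   BOOL_SWITCH(seqlen % chunk == 0, kIsEvenLen, [&] {
//     BOOL_SWITCH(index_ptr != nullptr, kUseIndex, [&] {
//       run_kernel<kIsEvenLen, kUseIndex>(params, stream);
//     });
//   });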
csrc/moe/marlin_moe_ops.cu (new file, 1740 lines)
(file diff suppressed because it is too large)
csrc/moe/marlin_moe_ops.h (new file, 12 lines)
@@ -0,0 +1,12 @@
#pragma once

#include <torch/all.h>

torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    const torch::Tensor& g_idx, const torch::Tensor& perm,
    torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k,
    bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size,
    bool replicate_input, bool apply_weights);

@@ -1,5 +1,6 @@
#include "registration.h"
#include "core/registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  // Apply topk softmax to the gating outputs.
@@ -7,6 +8,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
      "token_expert_indices, Tensor gating_output) -> ()");
  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);

#ifndef USE_ROCM
  m.def(
      "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
      "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
      "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int "
      "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, "
      "bool replicate_input, bool apply_weights) -> Tensor");

  m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

csrc/ops.h
@@ -3,6 +3,8 @@
|
||||
#include <optional>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
@ -61,12 +63,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& scales,
|
||||
const torch::Tensor& codebook_partition_sizes,
|
||||
const std::vector<int64_t>& codebook_partition_sizes,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
const torch::Tensor& codebook_partition_sizes);
|
||||
torch::Tensor aqlm_dequant(
|
||||
const torch::Tensor& codes, const torch::Tensor& codebooks,
|
||||
const std::vector<int64_t>& codebook_partition_sizes);
|
||||
|
||||
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors, torch::Tensor _zeros,
|
||||
@ -81,19 +83,41 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k);
|
||||
|
||||
namespace machete {
|
||||
|
||||
std::vector<std::string> supported_schedules(
|
||||
vllm::ScalarTypeTorchPtr const& btype);
|
||||
|
||||
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
vllm::ScalarTypeTorchPtr const& btype,
|
||||
c10::optional<torch::Tensor> const& scales,
|
||||
c10::optional<torch::Tensor> const& zeros,
|
||||
c10::optional<int64_t> group_size,
|
||||
c10::optional<torch::Tensor> const& C,
|
||||
c10::optional<double> alpha, c10::optional<double> beta,
|
||||
c10::optional<std::string> schedule);
|
||||
|
||||
torch::Tensor prepack_B(torch::Tensor const& B,
|
||||
vllm::ScalarTypeTorchPtr const& btype);
|
||||
|
||||
}; // namespace machete
|
||||
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp);
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
@ -102,6 +126,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
|
||||
int64_t n);
|
||||
|
||||
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
|
||||
int64_t type, int64_t row);
|
||||
|
||||
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
|
||||
int64_t row);
|
||||
|
||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
@ -114,6 +147,21 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
||||
torch::Tensor const& b_q_weight,
|
||||
torch::Tensor const& s_tok,
|
||||
torch::Tensor const& s_ch,
|
||||
torch::Tensor const& s_group,
|
||||
torch::Tensor& workspace, int64_t size_m,
|
||||
int64_t size_n, int64_t size_k);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
@ -147,6 +195,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
std::vector<torch::Tensor> selective_scan_fwd(
|
||||
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
||||
const torch::Tensor& B, const torch::Tensor& C,
|
||||
const c10::optional<torch::Tensor>& D_,
|
||||
const c10::optional<torch::Tensor>& z_,
|
||||
const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
||||
const c10::optional<torch::Tensor>& index_,
|
||||
const c10::optional<torch::Tensor>& x);
|
||||
|
||||
at::Tensor causal_conv1d_update(const at::Tensor& x,
|
||||
const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
bool silu_activation);
|
||||
|
||||
at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const c10::optional<at::Tensor>& seq_idx_,
|
||||
const c10::optional<at::Tensor>& initial_states_,
|
||||
const c10::optional<at::Tensor>& final_states_out_,
|
||||
bool silu_activation);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
||||
|
@ -1,217 +0,0 @@
|
||||
Contains code from https://github.com/punica-ai/punica
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
|
||||
|
||||
Apache-2.0
|
||||
* third_party/nvbench (with LLVM exception)
|
||||
* third_party/flashinfer
|
||||
|
||||
BSD-3-Clause:
|
||||
* third_party/cutlass
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
|
@ -1,218 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale);
|
||||
|
||||
// clang-format off
|
||||
|
||||
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 640) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 896) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
f(in_T, out_T, W_T, narrow, 1216) \
|
||||
f(in_T, out_T, W_T, narrow, 1280) \
|
||||
f(in_T, out_T, W_T, narrow, 1536) \
|
||||
f(in_T, out_T, W_T, narrow, 1664) \
|
||||
f(in_T, out_T, W_T, narrow, 1728) \
|
||||
f(in_T, out_T, W_T, narrow, 1792) \
|
||||
f(in_T, out_T, W_T, narrow, 2048) \
|
||||
f(in_T, out_T, W_T, narrow, 2240) \
|
||||
f(in_T, out_T, W_T, narrow, 2304) \
|
||||
f(in_T, out_T, W_T, narrow, 2368) \
|
||||
f(in_T, out_T, W_T, narrow, 2432) \
|
||||
f(in_T, out_T, W_T, narrow, 2560) \
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 2816) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3328) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 3712) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
f(in_T, out_T, W_T, narrow, 4480) \
|
||||
f(in_T, out_T, W_T, narrow, 4608) \
|
||||
f(in_T, out_T, W_T, narrow, 4736) \
|
||||
f(in_T, out_T, W_T, narrow, 4864) \
|
||||
f(in_T, out_T, W_T, narrow, 5120) \
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 5888) \
|
||||
f(in_T, out_T, W_T, narrow, 6144) \
|
||||
f(in_T, out_T, W_T, narrow, 6400) \
|
||||
f(in_T, out_T, W_T, narrow, 6848) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
f(in_T, out_T, W_T, narrow, 7424) \
|
||||
f(in_T, out_T, W_T, narrow, 8192) \
|
||||
f(in_T, out_T, W_T, narrow, 8960) \
|
||||
f(in_T, out_T, W_T, narrow, 9216) \
|
||||
f(in_T, out_T, W_T, narrow, 9472) \
|
||||
f(in_T, out_T, W_T, narrow, 10240) \
|
||||
f(in_T, out_T, W_T, narrow, 11008) \
|
||||
f(in_T, out_T, W_T, narrow, 11264) \
|
||||
f(in_T, out_T, W_T, narrow, 12288) \
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 14784) \
|
||||
f(in_T, out_T, W_T, narrow, 14848) \
|
||||
f(in_T, out_T, W_T, narrow, 15360) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 18944) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
f(in_T, out_T, W_T, narrow, 22528) \
|
||||
f(in_T, out_T, W_T, narrow, 24576) \
|
||||
f(in_T, out_T, W_T, narrow, 27392) \
|
||||
f(in_T, out_T, W_T, narrow, 27648) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 29568) \
|
||||
f(in_T, out_T, W_T, narrow, 29696) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
f(in_T, out_T, W_T, narrow, 32512) \
|
||||
f(in_T, out_T, W_T, narrow, 32768) \
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 43264) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
f(in_T, out_T, W_T, narrow, 49408) \
|
||||
f(in_T, out_T, W_T, narrow, 60544) \
|
||||
f(in_T, out_T, W_T, narrow, 60672) \
|
||||
f(in_T, out_T, W_T, narrow, 64000) \
|
||||
f(in_T, out_T, W_T, narrow, 64256) \
|
||||
f(in_T, out_T, W_T, narrow, 64512) \
|
||||
f(in_T, out_T, W_T, narrow, 102400) \
|
||||
f(in_T, out_T, W_T, narrow, 102656) \
|
||||
f(in_T, out_T, W_T, narrow, 102912) \
|
||||
f(in_T, out_T, W_T, narrow, 128000) \
|
||||
f(in_T, out_T, W_T, narrow, 128256) \
|
||||
f(in_T, out_T, W_T, narrow, 128512) \
|
||||
|
||||
|
||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||
// and vllm/tests/lora/test_punica.py
|
||||
|
||||
// Used for defining kernels going from the variety of
|
||||
// dim in to the narrow dim out
|
||||
// Using it for the fully sharded column
|
||||
// parallel LoRA A which splits the rank dim
|
||||
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, 128, narrow) \
|
||||
f(in_T, out_T, W_T, 256, narrow) \
|
||||
f(in_T, out_T, W_T, 512, narrow) \
|
||||
f(in_T, out_T, W_T, 640, narrow) \
|
||||
f(in_T, out_T, W_T, 768, narrow) \
|
||||
f(in_T, out_T, W_T, 896, narrow) \
|
||||
f(in_T, out_T, W_T, 1024, narrow) \
|
||||
f(in_T, out_T, W_T, 1152, narrow) \
|
||||
f(in_T, out_T, W_T, 1216, narrow) \
|
||||
f(in_T, out_T, W_T, 1280, narrow) \
|
||||
f(in_T, out_T, W_T, 1536, narrow) \
|
||||
f(in_T, out_T, W_T, 1664, narrow) \
|
||||
f(in_T, out_T, W_T, 1728, narrow) \
|
||||
f(in_T, out_T, W_T, 1792, narrow) \
|
||||
f(in_T, out_T, W_T, 2048, narrow) \
|
||||
f(in_T, out_T, W_T, 2240, narrow) \
|
||||
f(in_T, out_T, W_T, 2304, narrow) \
|
||||
f(in_T, out_T, W_T, 2368, narrow) \
|
||||
f(in_T, out_T, W_T, 2432, narrow) \
|
||||
f(in_T, out_T, W_T, 2560, narrow) \
|
||||
f(in_T, out_T, W_T, 2752, narrow) \
|
||||
f(in_T, out_T, W_T, 2816, narrow) \
|
||||
f(in_T, out_T, W_T, 3072, narrow) \
|
||||
f(in_T, out_T, W_T, 3328, narrow) \
|
||||
f(in_T, out_T, W_T, 3456, narrow) \
|
||||
f(in_T, out_T, W_T, 3584, narrow) \
|
||||
f(in_T, out_T, W_T, 3712, narrow) \
|
||||
f(in_T, out_T, W_T, 4096, narrow) \
|
||||
f(in_T, out_T, W_T, 4480, narrow) \
|
||||
f(in_T, out_T, W_T, 4608, narrow) \
|
||||
f(in_T, out_T, W_T, 4736, narrow) \
|
||||
f(in_T, out_T, W_T, 4864, narrow) \
|
||||
f(in_T, out_T, W_T, 5120, narrow) \
|
||||
f(in_T, out_T, W_T, 5504, narrow) \
|
||||
f(in_T, out_T, W_T, 5632, narrow) \
|
||||
f(in_T, out_T, W_T, 5888, narrow) \
|
||||
f(in_T, out_T, W_T, 6144, narrow) \
|
||||
f(in_T, out_T, W_T, 6400, narrow) \
|
||||
f(in_T, out_T, W_T, 6848, narrow) \
|
||||
f(in_T, out_T, W_T, 6912, narrow) \
|
||||
f(in_T, out_T, W_T, 7168, narrow) \
|
||||
f(in_T, out_T, W_T, 7424, narrow) \
|
||||
f(in_T, out_T, W_T, 8192, narrow) \
|
||||
f(in_T, out_T, W_T, 8960, narrow) \
|
||||
f(in_T, out_T, W_T, 9216, narrow) \
|
||||
f(in_T, out_T, W_T, 9472, narrow) \
|
||||
f(in_T, out_T, W_T, 10240, narrow) \
|
||||
f(in_T, out_T, W_T, 11008, narrow) \
|
||||
f(in_T, out_T, W_T, 11264, narrow) \
|
||||
f(in_T, out_T, W_T, 12288, narrow) \
|
||||
f(in_T, out_T, W_T, 13696, narrow) \
|
||||
f(in_T, out_T, W_T, 13824, narrow) \
|
||||
f(in_T, out_T, W_T, 14336, narrow) \
|
||||
f(in_T, out_T, W_T, 14784, narrow) \
|
||||
f(in_T, out_T, W_T, 14848, narrow) \
|
||||
f(in_T, out_T, W_T, 15360, narrow) \
|
||||
f(in_T, out_T, W_T, 16384, narrow) \
|
||||
f(in_T, out_T, W_T, 18944, narrow) \
|
||||
f(in_T, out_T, W_T, 20480, narrow) \
|
||||
f(in_T, out_T, W_T, 22016, narrow) \
|
||||
f(in_T, out_T, W_T, 22528, narrow) \
|
||||
f(in_T, out_T, W_T, 24576, narrow) \
|
||||
f(in_T, out_T, W_T, 27392, narrow) \
|
||||
f(in_T, out_T, W_T, 27648, narrow) \
|
||||
f(in_T, out_T, W_T, 28672, narrow) \
|
||||
f(in_T, out_T, W_T, 29568, narrow) \
|
||||
f(in_T, out_T, W_T, 29696, narrow) \
|
||||
f(in_T, out_T, W_T, 32000, narrow) \
|
||||
f(in_T, out_T, W_T, 32256, narrow) \
|
||||
f(in_T, out_T, W_T, 32512, narrow) \
|
||||
f(in_T, out_T, W_T, 32768, narrow) \
|
||||
f(in_T, out_T, W_T, 33024, narrow) \
|
||||
f(in_T, out_T, W_T, 36864, narrow) \
|
||||
f(in_T, out_T, W_T, 43264, narrow) \
|
||||
f(in_T, out_T, W_T, 49152, narrow) \
|
||||
f(in_T, out_T, W_T, 49408, narrow) \
|
||||
f(in_T, out_T, W_T, 60544, narrow) \
|
||||
f(in_T, out_T, W_T, 60672, narrow) \
|
||||
f(in_T, out_T, W_T, 64000, narrow) \
|
||||
f(in_T, out_T, W_T, 64256, narrow) \
|
||||
f(in_T, out_T, W_T, 64512, narrow) \
|
||||
f(in_T, out_T, W_T, 102400, narrow) \
|
||||
f(in_T, out_T, W_T, 102656, narrow) \
|
||||
f(in_T, out_T, W_T, 102912, narrow) \
|
||||
f(in_T, out_T, W_T, 128000, narrow) \
|
||||
f(in_T, out_T, W_T, 128256, narrow) \
|
||||
f(in_T, out_T, W_T, 128512, narrow) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||
|
||||
|
||||
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
|
||||
f(in_T, out_T, W_T, 8, 64) \
|
||||
f(in_T, out_T, W_T, 16, 64) \
|
||||
f(in_T, out_T, W_T, 32, 64) \
|
||||
f(in_T, out_T, W_T, 64, 64)
|
||||
|
||||
// clang-format on
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
|
Some files were not shown because too many files have changed in this diff.