mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
527 Commits
Author | SHA1 | Date | |
---|---|---|---|
51c31bc10c | |||
3ad438c66f | |||
203d4f82ac | |||
991143cfcd | |||
8b2d3cbc1b | |||
9765b5c406 | |||
430530fc18 | |||
97356f3c7e | |||
f510395bbf | |||
6110c39dc8 | |||
d8658c8cc1 | |||
7bc94a0fdd | |||
756b30a5f3 | |||
395aa823ea | |||
26422e477b | |||
f342153b48 | |||
27a57cad52 | |||
98a42e7078 | |||
0267fef52a | |||
4716a32dd4 | |||
c0935c96d3 | |||
cb40b3ab6b | |||
515386ef3c | |||
a4075cba4d | |||
96aa014d1e | |||
1715056fef | |||
b51c1cc9d2 | |||
ce567a2926 | |||
d6ea427f04 | |||
14ccd94c89 | |||
8267b06c30 | |||
3492859b68 | |||
098e1776ba | |||
10e6322283 | |||
6d9aa00fc4 | |||
1182607e18 | |||
45b6ef6513 | |||
1956931436 | |||
e24336b5a7 | |||
d18f4e73f3 | |||
82c540bebf | |||
8f44facddd | |||
e66b629c04 | |||
76879342a3 | |||
566b57c5c4 | |||
0dc72273b8 | |||
a979d9771e | |||
8af890a865 | |||
dfeb2ecc3a | |||
3a243095e5 | |||
64172a976c | |||
f408d05c52 | |||
0b4997e05c | |||
c13ad1b7bd | |||
819924e749 | |||
01bfb22b41 | |||
e67c295b0c | |||
925f3332ca | |||
b0dfa91dd7 | |||
56a8652f33 | |||
6d93d35308 | |||
837e185142 | |||
42bc386129 | |||
8b268a46a7 | |||
41deac4a3d | |||
af9e53496f | |||
f8a12ecc7f | |||
3c5ab9b811 | |||
743a0b7402 | |||
bfdb1ba5c3 | |||
cf2f084d56 | |||
f721096d48 | |||
e90fc21f2e | |||
ea5f14e6ff | |||
b7050ca7df | |||
c188ecb080 | |||
865732342b | |||
4c07dd28c0 | |||
3bbff9e5ab | |||
6ebd02bdef | |||
523e30ea0c | |||
f1c0fc3919 | |||
6e435de766 | |||
426ec4ec67 | |||
80e254834d | |||
ba8ae1d84f | |||
84eaa68425 | |||
5ee14494e4 | |||
4ad521d8b5 | |||
9474e89ba4 | |||
20478c4d3a | |||
63e8b28a99 | |||
cc63d03fbb | |||
2a60c9bd17 | |||
c614cfee58 | |||
7341c77d69 | |||
ef65dcfa6f | |||
6a9c583e73 | |||
b37cdce2b1 | |||
b30880a762 | |||
49eedea373 | |||
9fdf3de346 | |||
c0c17d4896 | |||
097aa0ea22 | |||
482b0adf1b | |||
8c654c045f | |||
9101d832e6 | |||
93348d9458 | |||
abfc4f3387 | |||
6b78837b29 | |||
120157fd2a | |||
8e67598aa6 | |||
ad50bf4b25 | |||
cf6ff18246 | |||
14e3f9a1b2 | |||
3123f15138 | |||
413366e9a2 | |||
10585e035e | |||
fb96c1e98c | |||
8fa7357f2d | |||
a7af4538ca | |||
604f235937 | |||
14b8ae02e7 | |||
03d37f2441 | |||
a7c871680e | |||
429284dc37 | |||
253a98078a | |||
21539e6856 | |||
b522c4476f | |||
78b6c4845a | |||
b983ba35bd | |||
54be8a0be2 | |||
dfc77408bd | |||
c17ca8ef18 | |||
06ec486794 | |||
8fe8386591 | |||
a37415c31b | |||
81653d9688 | |||
eeab52a4ff | |||
c33afd89f5 | |||
7e9bd08f60 | |||
ae0ccb4017 | |||
739c350c19 | |||
ba8dc958a3 | |||
e221910e77 | |||
b167109ba1 | |||
602358f8a8 | |||
49a3c8662b | |||
b0925b3878 | |||
654865e21d | |||
c9415c19d3 | |||
4c922709b6 | |||
657061fdce | |||
2f8844ba08 | |||
4b59f00e91 | |||
9e8744a545 | |||
e4a28e5316 | |||
0bba88df03 | |||
8437bae6ef | |||
f48c6791b7 | |||
c2c5e0909a | |||
1cb0cc2975 | |||
99c3cfb83c | |||
1ece1ae829 | |||
c59e120c55 | |||
d2339d6840 | |||
b35cc93420 | |||
8cbba4622c | |||
385da2dae2 | |||
2daf23ab0c | |||
cbf4c05b15 | |||
d3c04b6a39 | |||
4cb3b924cd | |||
a33ce60c66 | |||
24aecf421a | |||
2efce05dc3 | |||
8999ec3c16 | |||
05af6da8d9 | |||
9a4548bae7 | |||
ff578cae54 | |||
22de45235c | |||
76e8a70476 | |||
9cbc7e5f3b | |||
27a7b070db | |||
901cf4c52b | |||
d0fae88114 | |||
17c3103c56 | |||
996d095c54 | |||
d65fac2738 | |||
ce4f5a29fb | |||
baee28c46c | |||
29e70e3e88 | |||
82091b864a | |||
c0c2335ce0 | |||
90fbf12540 | |||
49d849b3ab | |||
27ca23dc00 | |||
54d3544784 | |||
703e42ee4b | |||
29a8d6a554 | |||
2c08ff23c0 | |||
bfdcfa6a05 | |||
9289e577ec | |||
a6d471c759 | |||
01a5d18a53 | |||
929b4f2973 | |||
3b7178cfa4 | |||
e46fa5d52e | |||
a8683102cc | |||
71bcaf99e2 | |||
8b430d7dea | |||
e0ade06d63 | |||
4bd18ec0c7 | |||
2410e320b3 | |||
48a8f4a7fd | |||
4dd6416faf | |||
c1c0d00b88 | |||
d9f726c4d0 | |||
d6e4a130b0 | |||
cfc15a1031 | |||
70f3e8e3a1 | |||
ef978fe411 | |||
f7c1234990 | |||
57f044945f | |||
4caf7044e0 | |||
6f32cddf1c | |||
c530e2cfe3 | |||
fd5dcc5c81 | |||
93dc5a2870 | |||
95529e3253 | |||
344020c926 | |||
5574081c49 | |||
d7f396486e | |||
8fbd84bf78 | |||
7d2dcce175 | |||
dc903e70ac | |||
a9c8212895 | |||
c20ecb6a51 | |||
5253edaacb | |||
017d9f1515 | |||
181b27d881 | |||
63e2a6419d | |||
264017a2bf | |||
e433c115bc | |||
86fd8bb0ac | |||
ab3a5a8259 | |||
a61f0521b8 | |||
537c9755a7 | |||
786b7f18a5 | |||
8f36444c4f | |||
185b2c29e2 | |||
5f08050d8d | |||
64da65b322 | |||
5255d99dc5 | |||
4f2ad11135 | |||
d7afab6d3a | |||
31348dff03 | |||
25e86b6a61 | |||
4efbac6d35 | |||
87069ccf68 | |||
7e45107f51 | |||
0c48b37c31 | |||
7eacffd951 | |||
2a543d6efe | |||
317b29de0f | |||
a463c333dd | |||
ea356004d4 | |||
5c976a7e1a | |||
f964493274 | |||
a4211a4dc3 | |||
563836496a | |||
4ca2c358b1 | |||
0580aab02f | |||
3711811b1d | |||
65b89d16ee | |||
931746bc6d | |||
c81dddb45c | |||
fe6d09ae61 | |||
ed70c70ea3 | |||
f0d4e14557 | |||
2ccee3def6 | |||
b92adec8e8 | |||
56f738ae9b | |||
72d3a30c63 | |||
c9b45adeeb | |||
5a6c81b051 | |||
51cd22ce56 | |||
5ed704ec8c | |||
4abf6336ec | |||
0e163fce18 | |||
96b6f475dd | |||
c410f5d020 | |||
bb8c697ee0 | |||
b9e96b17de | |||
923797fea4 | |||
cd9e60c76c | |||
93b38bea5d | |||
d0d93b92b1 | |||
89efcf1ce5 | |||
c664b0e683 | |||
d69ff0cbbb | |||
1af090b57d | |||
3dad944485 | |||
105a40f53a | |||
bbe9bd9684 | |||
4f65af0e25 | |||
d79ced3292 | |||
ab40644669 | |||
5d60def02c | |||
ea8489fce2 | |||
1b20639a43 | |||
b72af8f1ed | |||
9090bf02e7 | |||
7d648418b8 | |||
89be30fa7d | |||
f8ecb84c02 | |||
5f036d2bcc | |||
380170038e | |||
220a47627b | |||
beb89f68b4 | |||
390b495ff3 | |||
3a0e1fc070 | |||
6b7de1a030 | |||
5265631d15 | |||
2832e7b9f9 | |||
3a7dd7e367 | |||
223c19224b | |||
f1f6cc10c7 | |||
3209b49033 | |||
1e4277d2d1 | |||
9b945daaf1 | |||
9c1352eb57 | |||
7a0b011dd5 | |||
63e835cbcc | |||
94b5edeb53 | |||
ab7e6006d6 | |||
18bfcdd05c | |||
71d63ed72e | |||
d75c40734a | |||
5b23c3f26f | |||
00efdc84ba | |||
91a61da9b1 | |||
ef9b636e2d | |||
2709c0009a | |||
dd7e8f5f64 | |||
d2a68364c4 | |||
7e1081139d | |||
18473cf498 | |||
4df417d059 | |||
5d80a9178b | |||
8a25d3a71a | |||
d10f8e1d43 | |||
14cc317ba4 | |||
e1957c6ebd | |||
8cd5a992bf | |||
947f0b23cc | |||
f780504d12 | |||
bfc072addf | |||
2a18da257c | |||
6e01e8c1c8 | |||
9f659bf07f | |||
35c4bc20d9 | |||
218dc2ccda | |||
827cbcd37c | |||
cb7a1c1cbf | |||
7878958c0d | |||
ce036244c9 | |||
48cf1e413c | |||
97460585d9 | |||
f745847ef7 | |||
6549aef245 | |||
50376faa7b | |||
4b61c6b669 | |||
79d64c4954 | |||
74cd5abdd1 | |||
28c3f12104 | |||
c884819135 | |||
05921a9a7a | |||
d0215a58e7 | |||
937e7b7d7c | |||
aee8ef661a | |||
2e0b6e7757 | |||
941767127c | |||
74d8d77626 | |||
fd4ea8ef5c | |||
1066cbd152 | |||
6ef00b03a2 | |||
9140561059 | |||
77af974b40 | |||
4934d49274 | |||
358c328d69 | |||
4aaafdd289 | |||
66b108d142 | |||
e0ff920001 | |||
face83c7ec | |||
1db83e31a2 | |||
a1b9cb2a34 | |||
3a4fd5ca59 | |||
c17daa9f89 | |||
bd29cf3d3a | |||
31bff69151 | |||
ba4f826738 | |||
de60a3fb93 | |||
21d5daa4ac | |||
290e015c6c | |||
1b7c791d60 | |||
bbe4466fd9 | |||
08133c4d1a | |||
76a7983b23 | |||
8041b7305e | |||
3ec8c25cd0 | |||
671af2b1c0 | |||
6f41f0e377 | |||
2c9b638065 | |||
a7347d9a6d | |||
f8c688d746 | |||
c9fadda543 | |||
30fb0956df | |||
3a765bd5e1 | |||
26c52a5ea6 | |||
c3372e87be | |||
b0a1d667b0 | |||
e1d5402238 | |||
3d1cfbfc74 | |||
37ca558103 | |||
eed74a558f | |||
2acd76f346 | |||
b81a6a6bb3 | |||
0fbfc4b81b | |||
c06170cc8e | |||
614856da25 | |||
05bdf4eaf3 | |||
6774bd50b0 | |||
31c1f3255e | |||
21d93c140d | |||
f1c8520146 | |||
096827c284 | |||
6565d9e33e | |||
f375ec8440 | |||
518369d78c | |||
30bad5c492 | |||
3fefe271ec | |||
6428f1d051 | |||
7e1b21daac | |||
cb3f30c600 | |||
f3e024bece | |||
31d2ab4aff | |||
eb17212858 | |||
4dd4b5c538 | |||
6120e5aaea | |||
2eaa81b236 | |||
81ce2a4b26 | |||
5dd80d3777 | |||
beeee69bc9 | |||
9bf28d0b69 | |||
c0ce15dfb2 | |||
b9bcdc7158 | |||
4ff0203987 | |||
b5f882cc98 | |||
2e8fc0d4c3 | |||
dacaf5a400 | |||
24cde76a15 | |||
1aa1361510 | |||
fe470ae5ad | |||
3a8c2381f7 | |||
c85b80c2b6 | |||
2b981012a6 | |||
6ccc0bfffb | |||
c8e7eb1eb3 | |||
24f60a54f4 | |||
42c02f5892 | |||
ebede26ebf | |||
d940ce497e | |||
05ff90b692 | |||
1d9b737e05 | |||
60dc62dc9e | |||
0f90effc66 | |||
464dd985e3 | |||
c07a442854 | |||
cd3aa153a4 | |||
9b294976a2 | |||
5313c2cb8b | |||
5f09cbdb63 | |||
4cefa9b49b | |||
f86bd6190a | |||
e5452ddfd6 | |||
d06980dfa7 | |||
66785cc05c | |||
05a38612b0 | |||
d27f4bae39 | |||
8d8c2f6ffe | |||
51d3cb951d | |||
e74b1736a1 | |||
f07c1ceaa5 | |||
63b2206ad0 | |||
27feead2f8 | |||
c782195662 | |||
0f621c2c7d | |||
a9e4574261 | |||
0229c386c5 | |||
a7b3e33078 | |||
e19a64c7ef | |||
1cb4ad8de9 | |||
6ed068a71a | |||
708e6c18b0 | |||
b943890484 | |||
a1125ad4df | |||
a8b150c595 | |||
665cbcec4b | |||
7c600440f7 | |||
e0c6f556e8 | |||
de23687d16 | |||
4cea74c73b | |||
a921d8be9d | |||
094f716bf2 | |||
7d761fe3c1 | |||
cf35d8f3d7 | |||
4bb6b67188 | |||
819b18e7ba | |||
19849db573 | |||
3d4ceb292c | |||
f5a37c6c6c | |||
32c927b53f | |||
5ffc0d13a2 | |||
112627e8b2 | |||
37c1e3c218 | |||
06e9ebebd5 |
18
.buildkite/download-images.sh
Normal file
18
.buildkite/download-images.sh
Normal file
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
|
||||
mkdir -p images
|
||||
cd images
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg
|
||||
|
||||
cd -
|
38
.buildkite/run-amd-test.sh
Normal file
38
.buildkite/run-amd-test.sh
Normal file
@ -0,0 +1,38 @@
|
||||
# This script build the ROCm docker image and run the API server inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
|
||||
# Print ROCm version
|
||||
rocminfo
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t rocm -f Dockerfile.rocm .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f rocm || true; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image
|
||||
docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server &
|
||||
|
||||
# Wait for the server to start
|
||||
wait_for_server_to_start() {
|
||||
timeout=300
|
||||
counter=0
|
||||
|
||||
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
|
||||
sleep 1
|
||||
counter=$((counter + 1))
|
||||
if [ $counter -ge $timeout ]; then
|
||||
echo "Timeout after $timeout seconds"
|
||||
break
|
||||
fi
|
||||
done
|
||||
}
|
||||
wait_for_server_to_start
|
||||
|
||||
# Test a simple prompt
|
||||
curl -X POST -H "Content-Type: application/json" \
|
||||
localhost:8000/generate \
|
||||
-d '{"prompt": "San Francisco is a"}'
|
72
.buildkite/run-benchmarks.sh
Normal file
72
.buildkite/run-benchmarks.sh
Normal file
@ -0,0 +1,72 @@
|
||||
# This script is run by buildkite to run the benchmarks and upload the results to buildkite
|
||||
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# cd into parent directory of this file
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")/.."
|
||||
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# run python-based benchmarks and upload the result to buildkite
|
||||
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
|
||||
bench_latency_exit_code=$?
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
|
||||
bench_throughput_exit_code=$?
|
||||
|
||||
# run server-based benchmarks and upload the result to buildkite
|
||||
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
|
||||
server_pid=$!
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
# wait for server to start, timeout after 600 seconds
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer meta-llama/Llama-2-7b-chat-hf \
|
||||
--save-result \
|
||||
2>&1 | tee benchmark_serving.txt
|
||||
bench_serving_exit_code=$?
|
||||
kill $server_pid
|
||||
|
||||
# write the results into a markdown file
|
||||
echo "### Latency Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line
|
||||
|
||||
echo "### Throughput Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
|
||||
|
||||
echo "### Serving Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
echo '```' >> benchmark_results.md
|
||||
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
|
||||
echo '```' >> benchmark_results.md
|
||||
|
||||
# upload the results to buildkite
|
||||
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
|
||||
|
||||
# exit with the exit code of the benchmarks
|
||||
if [ $bench_latency_exit_code -ne 0 ]; then
|
||||
exit $bench_latency_exit_code
|
||||
fi
|
||||
|
||||
if [ $bench_throughput_exit_code -ne 0 ]; then
|
||||
exit $bench_throughput_exit_code
|
||||
fi
|
||||
|
||||
if [ $bench_serving_exit_code -ne 0 ]; then
|
||||
exit $bench_serving_exit_code
|
||||
fi
|
||||
|
||||
/workspace/buildkite-agent artifact upload openai-*.json
|
97
.buildkite/test-pipeline.yaml
Normal file
97
.buildkite/test-pipeline.yaml
Normal file
@ -0,0 +1,97 @@
|
||||
# In this file, you can add more tests to run either by adding a new step or
|
||||
# adding a new command to an existing step. See different options here for examples.
|
||||
# This script will be feed into Jinja template in `test-template.j2` to generate
|
||||
# the final pipeline yaml file.
|
||||
|
||||
steps:
|
||||
- label: Regression Test
|
||||
command: pytest -v -s test_regression.py
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: AsyncEngine Test
|
||||
command: pytest -v -s async_engine
|
||||
|
||||
- label: Basic Correctness Test
|
||||
command: pytest -v -s basic_correctness
|
||||
|
||||
- label: Core Test
|
||||
command: pytest -v -s core
|
||||
|
||||
- label: Distributed Comm Ops Test
|
||||
command: pytest -v -s test_comm_ops.py
|
||||
working_dir: "/vllm-workspace/tests/distributed"
|
||||
num_gpus: 2 # only support 1 or 2 for now.
|
||||
|
||||
- label: Distributed Tests
|
||||
working_dir: "/vllm-workspace/tests/distributed"
|
||||
num_gpus: 2 # only support 1 or 2 for now.
|
||||
commands:
|
||||
- pytest -v -s test_pynccl.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
|
||||
|
||||
- label: Engine Test
|
||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py
|
||||
|
||||
- label: Entrypoints Test
|
||||
command: pytest -v -s entrypoints
|
||||
|
||||
- label: Examples Test
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
commands:
|
||||
# install aws cli for llava_example.py
|
||||
- pip install awscli
|
||||
- python3 offline_inference.py
|
||||
- python3 offline_inference_with_prefix.py
|
||||
- python3 llm_engine_example.py
|
||||
- python3 llava_example.py
|
||||
|
||||
- label: Kernels Test %N
|
||||
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
|
||||
- label: Models Test
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
- pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
|
||||
|
||||
- label: Llava Test
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
- pytest -v -s models/test_llava.py
|
||||
|
||||
- label: Prefix Caching Test
|
||||
commands:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
- label: Samplers Test
|
||||
command: pytest -v -s samplers
|
||||
|
||||
- label: LogitsProcessor Test
|
||||
command: pytest -v -s test_logits_processor.py
|
||||
|
||||
- label: Worker Test
|
||||
command: pytest -v -s worker
|
||||
|
||||
- label: Speculative decoding tests
|
||||
command: pytest -v -s spec_decode
|
||||
|
||||
- label: LoRA Test %N
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
|
||||
- label: Metrics Test
|
||||
command: pytest -v -s metrics
|
||||
|
||||
- label: Benchmarks
|
||||
working_dir: "/vllm-workspace/.buildkite"
|
||||
commands:
|
||||
- pip install aiohttp
|
||||
- bash run-benchmarks.sh
|
||||
|
||||
- label: Documentation Build
|
||||
working_dir: "/vllm-workspace/docs"
|
||||
no_gpu: True
|
||||
commands:
|
||||
- pip install -r requirements-docs.txt
|
||||
- SPHINXOPTS=\"-W\" make html
|
66
.buildkite/test-template.j2
Normal file
66
.buildkite/test-template.j2
Normal file
@ -0,0 +1,66 @@
|
||||
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
|
||||
{% set default_num_gpu = 1 %}
|
||||
{% set default_working_dir = "/vllm-workspace/tests" %}
|
||||
|
||||
steps:
|
||||
- label: "AMD Test"
|
||||
agents:
|
||||
queue: amd
|
||||
command: bash .buildkite/run-amd-test.sh
|
||||
|
||||
- label: ":docker: build image"
|
||||
commands:
|
||||
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
||||
- "docker push {{ docker_image }}"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
- wait
|
||||
|
||||
{% for step in steps %}
|
||||
- label: "{{ step.label }}"
|
||||
agents:
|
||||
queue: kubernetes
|
||||
soft_fail: {{ step.soft_fail or false }}
|
||||
{% if step.parallelism %}
|
||||
parallelism: {{ step.parallelism }}
|
||||
{% endif %}
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
plugins:
|
||||
- kubernetes:
|
||||
podSpec:
|
||||
volumes:
|
||||
- name: dshm
|
||||
emptyDir:
|
||||
medium: Memory
|
||||
containers:
|
||||
- image: "{{ docker_image }}"
|
||||
command: ["bash"]
|
||||
args:
|
||||
- '-c'
|
||||
- "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
|
||||
{% if not step.no_gpu %}
|
||||
resources:
|
||||
requests:
|
||||
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
|
||||
limits:
|
||||
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
|
||||
{% endif %}
|
||||
env:
|
||||
- name: VLLM_USAGE_SOURCE
|
||||
value: ci-test
|
||||
- name: HF_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: hf-token-secret
|
||||
key: token
|
||||
volumeMounts:
|
||||
- mountPath: /dev/shm
|
||||
name: dshm
|
||||
{% endfor %}
|
1
.dockerignore
Normal file
1
.dockerignore
Normal file
@ -0,0 +1 @@
|
||||
vllm/*.so
|
22
.github/ISSUE_TEMPLATE/100-documentation.yml
vendored
Normal file
22
.github/ISSUE_TEMPLATE/100-documentation.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: 📚 Documentation
|
||||
description: Report an issue related to https://docs.vllm.ai/
|
||||
title: "[Doc]: "
|
||||
labels: ["documentation"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 📚 The doc issue
|
||||
description: >
|
||||
A clear and concise description of what content in https://docs.vllm.ai/ is an issue.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Suggest a potential alternative/fix
|
||||
description: >
|
||||
Tell us how we could improve the documentation in this regard.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
39
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
Normal file
39
.github/ISSUE_TEMPLATE/200-installation.yml
vendored
Normal file
@ -0,0 +1,39 @@
|
||||
name: 🛠️ Installation
|
||||
description: Report an issue here when you hit errors during installation.
|
||||
title: "[Installation]: "
|
||||
labels: ["installation"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your current environment
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
value: |
|
||||
```text
|
||||
The output of `python collect_env.py`
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: How you are installing vllm
|
||||
description: |
|
||||
Paste the full command you are trying to execute.
|
||||
value: |
|
||||
```sh
|
||||
pip install -vvv vllm
|
||||
```
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
37
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
Normal file
37
.github/ISSUE_TEMPLATE/300-usage.yml
vendored
Normal file
@ -0,0 +1,37 @@
|
||||
name: 💻 Usage
|
||||
description: Raise an issue here if you don't know how to use vllm.
|
||||
title: "[Usage]: "
|
||||
labels: ["usage"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your current environment
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
value: |
|
||||
```text
|
||||
The output of `python collect_env.py`
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: How would you like to use vllm
|
||||
description: |
|
||||
A detailed description of how you want to use vllm.
|
||||
value: |
|
||||
I want to run inference of a [specific model](put link here). I don't know how to integrate it with vllm.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
81
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
Normal file
81
.github/ISSUE_TEMPLATE/400-bug report.yml
vendored
Normal file
@ -0,0 +1,81 @@
|
||||
name: 🐛 Bug report
|
||||
description: Raise an issue here if you find a bug.
|
||||
title: "[Bug]: "
|
||||
labels: ["bug"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your current environment
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
value: |
|
||||
```text
|
||||
The output of `python collect_env.py`
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 🐛 Describe the bug
|
||||
description: |
|
||||
Please provide a clear and concise description of what the bug is.
|
||||
|
||||
If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
llm = LLM(model="facebook/opt-125m")
|
||||
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Print the outputs.
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
```
|
||||
|
||||
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
|
||||
|
||||
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
|
||||
placeholder: |
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
```python
|
||||
# Sample code to reproduce the problem
|
||||
```
|
||||
|
||||
```
|
||||
The error message you got, with the full traceback.
|
||||
```
|
||||
validations:
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
⚠️ Please separate bugs of `transformers` implementation or usage from bugs of `vllm`. If you think anything is wrong with the models' output:
|
||||
|
||||
- Try the counterpart of `transformers` first. If the error appears, please go to [their issues](https://github.com/huggingface/transformers/issues?q=is%3Aissue+is%3Aopen+sort%3Aupdated-desc).
|
||||
|
||||
- If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect.
|
||||
|
||||
Thanks for contributing 🎉!
|
31
.github/ISSUE_TEMPLATE/500-feature request.yml
vendored
Normal file
31
.github/ISSUE_TEMPLATE/500-feature request.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: 🚀 Feature request
|
||||
description: Submit a proposal/request for a new vllm feature
|
||||
title: "[Feature]: "
|
||||
labels: ["feature request"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 🚀 The feature, motivation and pitch
|
||||
description: >
|
||||
A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Alternatives
|
||||
description: >
|
||||
A description of any alternative solutions or features you've considered, if any.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: >
|
||||
Add any other context or screenshots about the feature request.
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
33
.github/ISSUE_TEMPLATE/600-new model.yml
vendored
Normal file
33
.github/ISSUE_TEMPLATE/600-new model.yml
vendored
Normal file
@ -0,0 +1,33 @@
|
||||
name: 🤗 Support request for a new model from huggingface
|
||||
description: Submit a proposal/request for a new model from huggingface
|
||||
title: "[New Model]: "
|
||||
labels: ["new model"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/models/adding_model.html first to understand how to add a new model.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The model to consider.
|
||||
description: >
|
||||
A huggingface url, pointing to the model, e.g. https://huggingface.co/openai-community/gpt2 .
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The closest model vllm already supports.
|
||||
description: >
|
||||
Here is the list of models already supported by vllm: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models . Which model is the most similar to the model you want to add support for?
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: What's your difficulty of supporting the model you want?
|
||||
description: >
|
||||
For example, any new operators or new architecture?
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
51
.github/ISSUE_TEMPLATE/700-performance discussion.yml
vendored
Normal file
51
.github/ISSUE_TEMPLATE/700-performance discussion.yml
vendored
Normal file
@ -0,0 +1,51 @@
|
||||
name: ⚡ Discussion on the performance of vllm
|
||||
description: Submit a proposal/discussion about the performance of vllm
|
||||
title: "[Performance]: "
|
||||
labels: ["performance"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Proposal to improve performance
|
||||
description: >
|
||||
How do you plan to improve vllm's performance?
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Report of performance regression
|
||||
description: >
|
||||
Please provide detailed description of performance comparison to confirm the regression. You may want to run the benchmark script at https://github.com/vllm-project/vllm/tree/main/benchmarks .
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Misc discussion on performance
|
||||
description: >
|
||||
Anything about the performance.
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Your current environment (if you think it is necessary)
|
||||
description: |
|
||||
Please run the following and paste the output below.
|
||||
```sh
|
||||
wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
|
||||
# For security purposes, please feel free to check the contents of collect_env.py before running it.
|
||||
python collect_env.py
|
||||
```
|
||||
value: |
|
||||
```text
|
||||
The output of `python collect_env.py`
|
||||
```
|
||||
validations:
|
||||
required: false
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
21
.github/ISSUE_TEMPLATE/800-misc discussion.yml
vendored
Normal file
21
.github/ISSUE_TEMPLATE/800-misc discussion.yml
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
name: 🎲 Misc/random discussions that do not fit into the above categories.
|
||||
description: Submit a discussion as you like. Note that developers are heavily overloaded and we mainly rely on community users to answer these issues.
|
||||
title: "[Misc]: "
|
||||
labels: ["misc"]
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Anything you want to discuss about vllm.
|
||||
description: >
|
||||
Anything you want to discuss about vllm.
|
||||
validations:
|
||||
required: true
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
1
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@ -0,0 +1 @@
|
||||
blank_issues_enabled: false
|
64
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
64
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
FILL IN THE PR DESCRIPTION HERE
|
||||
|
||||
FIX #xxxx (*link existing issues this PR will resolve*)
|
||||
|
||||
**BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE**
|
||||
|
||||
---
|
||||
|
||||
<details>
|
||||
<!-- inside this <details> section, markdown rendering does not work, so we use raw html here. -->
|
||||
<summary><b> PR Checklist (Click to Expand) </b></summary>
|
||||
|
||||
<p>Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.</p>
|
||||
|
||||
<h3>PR Title and Classification</h3>
|
||||
<p>Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:</p>
|
||||
<ul>
|
||||
<li><code>[Bugfix]</code> for bug fixes.</li>
|
||||
<li><code>[CI/Build]</code> for build or continuous integration improvements.</li>
|
||||
<li><code>[Doc]</code> for documentation fixes and improvements.</li>
|
||||
<li><code>[Model]</code> for adding a new model or improving an existing model. Model name should appear in the title.</li>
|
||||
<li><code>[Frontend]</code> For changes on the vLLM frontend (e.g., OpenAI API server, <code>LLM</code> class, etc.) </li>
|
||||
<li><code>[Kernel]</code> for changes affecting CUDA kernels or other compute kernels.</li>
|
||||
<li><code>[Core]</code> for changes in the core vLLM logic (e.g., <code>LLMEngine</code>, <code>AsyncLLMEngine</code>, <code>Scheduler</code>, etc.)</li>
|
||||
<li><code>[Hardware][Vendor]</code> for hardware-specific changes. Vendor name should appear in the prefix (e.g., <code>[Hardware][AMD]</code>).</li>
|
||||
<li><code>[Misc]</code> for PRs that do not fit the above categories. Please use this sparingly.</li>
|
||||
</ul>
|
||||
<p><strong>Note:</strong> If the PR spans more than one category, please include all relevant prefixes.</p>
|
||||
|
||||
<h3>Code Quality</h3>
|
||||
|
||||
<p>The PR need to meet the following code quality standards:</p>
|
||||
|
||||
<ul>
|
||||
<li>We adhere to <a href="https://google.github.io/styleguide/pyguide.html">Google Python style guide</a> and <a href="https://google.github.io/styleguide/cppguide.html">Google C++ style guide</a>.</li>
|
||||
<li>Pass all linter checks. Please use <a href="https://github.com/vllm-project/vllm/blob/main/format.sh"><code>format.sh</code></a> to format your code.</li>
|
||||
<li>The code need to be well-documented to ensure future contributors can easily understand the code.</li>
|
||||
<li>Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.</li>
|
||||
<li>Please add documentation to <code>docs/source/</code> if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.</li>
|
||||
</ul>
|
||||
|
||||
<h3>Notes for Large Changes</h3>
|
||||
<p>Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with <code>rfc-required</code> and might not go through the PR.</p>
|
||||
|
||||
<h3>What to Expect for the Reviews</h3>
|
||||
|
||||
<p>The goal of the vLLM team is to be a <i>transparent reviewing machine</i>. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process: </p>
|
||||
|
||||
<ul>
|
||||
<li> After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.</li>
|
||||
<li> After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.</li>
|
||||
<li> After the review, the reviewer will put an <code> action-required</code> label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.</li>
|
||||
<li> Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
|
||||
</li>
|
||||
</ul>
|
||||
|
||||
<h3>Thank You</h3>
|
||||
|
||||
<p> Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone! </p>
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
|
2
.github/workflows/publish.yml
vendored
2
.github/workflows/publish.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: ['2.1.0']
|
||||
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
|
||||
cuda-version: ['11.8', '12.1']
|
||||
|
||||
steps:
|
||||
|
@ -1,4 +1,4 @@
|
||||
name: pylint
|
||||
name: ruff
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
@ -11,7 +11,7 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@ -25,7 +25,13 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint==2.8.2
|
||||
- name: Analysing the code with pylint
|
||||
pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1 isort==5.13.2
|
||||
- name: Analysing the code with ruff
|
||||
run: |
|
||||
pylint vllm tests
|
||||
ruff .
|
||||
- name: Spelling check with codespell
|
||||
run: |
|
||||
codespell --toml pyproject.toml
|
||||
- name: Run isort
|
||||
run: |
|
||||
isort . --check-only
|
2
.github/workflows/scripts/build.sh
vendored
2
.github/workflows/scripts/build.sh
vendored
@ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt
|
||||
|
||||
# Limit the number of parallel jobs to avoid OOM
|
||||
export MAX_JOBS=1
|
||||
# Make sure punica is built for the release (for LoRA)
|
||||
export VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
# Build
|
||||
$python_executable setup.py bdist_wheel --dist-dir=dist
|
||||
|
2
.github/workflows/yapf.yml
vendored
2
.github/workflows/yapf.yml
vendored
@ -28,4 +28,4 @@ jobs:
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive vllm tests
|
||||
yapf --diff --recursive .
|
||||
|
7
.gitignore
vendored
7
.gitignore
vendored
@ -177,3 +177,10 @@ _build/
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
|
||||
# Benchmark dataset
|
||||
*.json
|
||||
|
434
.pylintrc
434
.pylintrc
@ -1,434 +0,0 @@
|
||||
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||
# best-practices and style described in the Google Python style guide:
|
||||
# https://google.github.io/styleguide/pyguide.html
|
||||
#
|
||||
# Its canonical open-source location is:
|
||||
# https://google.github.io/styleguide/pylintrc
|
||||
|
||||
[MASTER]
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=docs
|
||||
|
||||
# Files or directories matching the regex patterns are skipped. The regex
|
||||
# matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=no
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=abstract-method,
|
||||
apply-builtin,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
backtick,
|
||||
bad-option-value,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
c-extension-no-member,
|
||||
consider-using-enumerate,
|
||||
cmp-builtin,
|
||||
cmp-method,
|
||||
coerce-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
div-method,
|
||||
duplicate-code,
|
||||
eq-without-hash,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
filter-builtin-not-iterating,
|
||||
fixme,
|
||||
getslice-method,
|
||||
global-statement,
|
||||
hex-method,
|
||||
idiv-method,
|
||||
implicit-str-concat-in-sequence,
|
||||
import-error,
|
||||
import-self,
|
||||
import-star-module-level,
|
||||
inconsistent-return-statements,
|
||||
input-builtin,
|
||||
intern-builtin,
|
||||
invalid-str-codec,
|
||||
locally-disabled,
|
||||
logging-fstring-interpolation, # added by vLLM
|
||||
logging-not-lazy, # added by vLLM
|
||||
long-builtin,
|
||||
long-suffix,
|
||||
map-builtin-not-iterating,
|
||||
misplaced-comparison-constant,
|
||||
missing-class-docstring, # TODO (vLLM): enable
|
||||
missing-function-docstring,
|
||||
missing-module-docstring, # TODO (vLLM): enable
|
||||
metaclass-assignment,
|
||||
next-method-called,
|
||||
next-method-defined,
|
||||
no-absolute-import,
|
||||
no-else-break,
|
||||
no-else-continue,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
no-init, # added
|
||||
no-member,
|
||||
no-name-in-module,
|
||||
no-self-use,
|
||||
nonzero-method,
|
||||
oct-method,
|
||||
old-division,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
old-raise-syntax,
|
||||
parameter-unpacking,
|
||||
print-statement,
|
||||
raising-string,
|
||||
range-builtin-not-iterating,
|
||||
raw_input-builtin,
|
||||
rdiv-method,
|
||||
reduce-builtin,
|
||||
relative-import,
|
||||
reload-builtin,
|
||||
round-builtin,
|
||||
setslice-method,
|
||||
signature-differs,
|
||||
standarderror-builtin,
|
||||
suppressed-message,
|
||||
sys-max-int,
|
||||
too-few-public-methods,
|
||||
too-many-ancestors,
|
||||
too-many-arguments,
|
||||
too-many-boolean-expressions,
|
||||
too-many-branches,
|
||||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-statements,
|
||||
trailing-newlines,
|
||||
unichr-builtin,
|
||||
unicode-builtin,
|
||||
unnecessary-pass,
|
||||
unpacking-in-except,
|
||||
unspecified-encoding,
|
||||
useless-else-on-loop,
|
||||
useless-object-inheritance,
|
||||
useless-suppression,
|
||||
using-cmp-argument,
|
||||
wrong-import-order,
|
||||
xrange-builtin,
|
||||
zip-builtin-not-iterating,
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=main,_
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=10
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=80
|
||||
|
||||
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
|
||||
# lines made too long by directives to pytype.
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=(?x)(
|
||||
^\s*(\#\ )?<?https?://\S+>?$|
|
||||
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=yes
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=99999
|
||||
|
||||
# String used as indentation unit. The internal Google style guide mandates 2
|
||||
# spaces. Google's externaly-published style guide says 4, consistent with
|
||||
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||
# projects (like TensorFlow).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=TODO
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=yes
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,
|
||||
TERMIOS,
|
||||
Bastion,
|
||||
rexec,
|
||||
sets
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant, absl
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls,
|
||||
class_
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=StandardError,
|
||||
Exception,
|
||||
BaseException
|
1
.yapfignore
Normal file
1
.yapfignore
Normal file
@ -0,0 +1 @@
|
||||
collect_env.py
|
286
CMakeLists.txt
Normal file
286
CMakeLists.txt
Normal file
@ -0,0 +1,286 @@
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
project(vllm_extensions LANGUAGES CXX)
|
||||
|
||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
|
||||
#
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
#
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
|
||||
|
||||
# Supported NVIDIA architectures.
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
|
||||
|
||||
# Supported AMD GPU architectures.
|
||||
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
|
||||
|
||||
#
|
||||
# Supported/expected torch versions for CUDA/ROCm.
|
||||
#
|
||||
# Currently, having an incorrect pytorch version results in a warning
|
||||
# rather than an error.
|
||||
#
|
||||
# Note: the CUDA torch version is derived from pyproject.toml and various
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions.
|
||||
#
|
||||
if (VLLM_PYTHON_EXECUTABLE)
|
||||
find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}")
|
||||
else()
|
||||
message(FATAL_ERROR
|
||||
"Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version"
|
||||
" before running cmake configure.")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Update cmake's `CMAKE_PREFIX_PATH` with torch location.
|
||||
#
|
||||
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
|
||||
|
||||
# Ensure the 'nvcc' command is in the PATH
|
||||
find_program(NVCC_EXECUTABLE nvcc)
|
||||
if (CUDA_FOUND AND NOT NVCC_EXECUTABLE)
|
||||
message(FATAL_ERROR "nvcc not found")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Import torch cmake configuration.
|
||||
# Torch also imports CUDA (and partially HIP) languages with some customizations,
|
||||
# so there is no need to do this explicitly with check_language/enable_language,
|
||||
# etc.
|
||||
#
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
#
|
||||
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
|
||||
# `libtorch_python.so` for linking against an extension. Torch's cmake
|
||||
# configuration does not include this library (presumably since the cmake
|
||||
# config is used for standalone C++ binaries that link against torch).
|
||||
# The `libtorch_python.so` library defines some of the glue code between
|
||||
# torch/python via pybind and is required by VLLM extensions for this
|
||||
# reason. So, add it by manually with `find_library` using torch's
|
||||
# installed library path.
|
||||
#
|
||||
find_library(torch_python_LIBRARY torch_python PATHS
|
||||
"${TORCH_INSTALL_PREFIX}/lib")
|
||||
|
||||
#
|
||||
# Set up GPU language and check the torch version and warn if it isn't
|
||||
# what is expected.
|
||||
#
|
||||
if (NOT HIP_FOUND AND CUDA_FOUND)
|
||||
set(VLLM_GPU_LANG "CUDA")
|
||||
|
||||
if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA})
|
||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
|
||||
"expected for CUDA build, saw ${Torch_VERSION} instead.")
|
||||
endif()
|
||||
elseif(HIP_FOUND)
|
||||
set(VLLM_GPU_LANG "HIP")
|
||||
|
||||
# Importing torch recognizes and sets up some HIP/ROCm configuration but does
|
||||
# not let cmake recognize .hip files. In order to get cmake to understand the
|
||||
# .hip extension automatically, HIP must be enabled explicitly.
|
||||
enable_language(HIP)
|
||||
|
||||
# ROCm 5.x
|
||||
if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
|
||||
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
|
||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
|
||||
"expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
|
||||
endif()
|
||||
|
||||
# ROCm 6.x
|
||||
if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
|
||||
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
|
||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
|
||||
"expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
|
||||
endif()
|
||||
else()
|
||||
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||
# the supported versions for the current language.
|
||||
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
|
||||
#
|
||||
override_gpu_arches(VLLM_GPU_ARCHES
|
||||
${VLLM_GPU_LANG}
|
||||
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
|
||||
|
||||
#
|
||||
# Query torch for additional GPU compilation flags for the given
|
||||
# `VLLM_GPU_LANG`.
|
||||
# The final set of arches is stored in `VLLM_GPU_FLAGS`.
|
||||
#
|
||||
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})
|
||||
|
||||
#
|
||||
# Set nvcc parallelism.
|
||||
#
|
||||
if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
|
||||
#
|
||||
# _C extension
|
||||
#
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/attention_kernels.cu"
|
||||
"csrc/pos_encoding_kernels.cu"
|
||||
"csrc/activation_kernels.cu"
|
||||
"csrc/layernorm_kernels.cu"
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq/q_gemm.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/moe_align_block_size_kernels.cu"
|
||||
"csrc/pybind.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/marlin_cuda_kernel.cu"
|
||||
"csrc/custom_all_reduce.cu")
|
||||
endif()
|
||||
|
||||
define_gpu_extension_target(
|
||||
_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
WITH_SOABI)
|
||||
|
||||
#
|
||||
# _moe_C extension
|
||||
#
|
||||
|
||||
set(VLLM_MOE_EXT_SRC
|
||||
"csrc/moe/moe_ops.cpp"
|
||||
"csrc/moe/topk_softmax_kernels.cu")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_moe_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${VLLM_MOE_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
WITH_SOABI)
|
||||
|
||||
#
|
||||
# _punica_C extension
|
||||
#
|
||||
|
||||
set(VLLM_PUNICA_EXT_SRC
|
||||
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
|
||||
"csrc/punica/punica_ops.cc")
|
||||
|
||||
#
|
||||
# Copy GPU compilation flags+update for punica
|
||||
#
|
||||
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
|
||||
"-D__CUDA_NO_HALF_OPERATORS__"
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__"
|
||||
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__")
|
||||
|
||||
#
|
||||
# Filter out CUDA architectures < 8.0 for punica.
|
||||
#
|
||||
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
|
||||
set(VLLM_PUNICA_GPU_ARCHES)
|
||||
foreach(ARCH ${VLLM_GPU_ARCHES})
|
||||
string_to_ver(CODE_VER ${ARCH})
|
||||
if (CODE_VER GREATER_EQUAL 8.0)
|
||||
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
endif()
|
||||
|
||||
if (VLLM_PUNICA_GPU_ARCHES)
|
||||
define_gpu_extension_target(
|
||||
_punica_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${VLLM_PUNICA_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
|
||||
WITH_SOABI)
|
||||
else()
|
||||
message(WARNING "Unable to create _punica_C target because none of the "
|
||||
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Add the `default` target which detects which extensions should be
|
||||
# built based on platform/architecture. This is the same logic that
|
||||
# setup.py uses to select which extensions should be built and should
|
||||
# be kept in sync.
|
||||
#
|
||||
# The `default` target makes direct use of cmake easier since knowledge
|
||||
# of which extensions are supported has been factored in, e.g.
|
||||
#
|
||||
# mkdir build && cd build
|
||||
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
|
||||
# cmake --build . --target default
|
||||
#
|
||||
add_custom_target(default)
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling C extension.")
|
||||
add_dependencies(default _C)
|
||||
endif()
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling moe extension.")
|
||||
add_dependencies(default _moe_C)
|
||||
|
||||
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
|
||||
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
|
||||
# there are supported target arches.
|
||||
if (VLLM_PUNICA_GPU_ARCHES AND
|
||||
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
|
||||
message(STATUS "Enabling punica extension.")
|
||||
add_dependencies(default _punica_C)
|
||||
endif()
|
||||
endif()
|
@ -45,31 +45,9 @@ pytest tests/
|
||||
If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
|
||||
If not, please file a new issue, providing as much relevant information as possible.
|
||||
|
||||
### Coding Style Guide
|
||||
### Pull Requests & Code Reviews
|
||||
|
||||
In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
|
||||
We include a formatting script [`format.sh`](./format.sh) to format the code.
|
||||
|
||||
### Pull Requests
|
||||
|
||||
When submitting a pull request:
|
||||
|
||||
1. Make sure your code has been rebased on top of the latest commit on the main branch.
|
||||
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
|
||||
3. Include a detailed description of the changes in the pull request.
|
||||
Explain why you made the changes you did.
|
||||
If your pull request fixes an open issue, please include a reference to it in the description.
|
||||
|
||||
### Code Reviews
|
||||
|
||||
All submissions, including submissions by project members, require a code review.
|
||||
To make the review process as smooth as possible, please:
|
||||
|
||||
1. Keep your changes as concise as possible.
|
||||
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
|
||||
2. Respond to all comments within a reasonable time frame.
|
||||
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
|
||||
Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for detailed guide for contribution.
|
||||
|
||||
### Thank You
|
||||
|
||||
|
100
Dockerfile
100
Dockerfile
@ -1,7 +1,17 @@
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip
|
||||
&& apt-get install -y python3-pip git
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-12.1/compat/
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
@ -14,34 +24,87 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
COPY requirements-dev.txt requirements-dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements-dev.txt
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
# image to build pytorch extensions
|
||||
|
||||
#################### EXTENSION BUILD IMAGE ####################
|
||||
FROM dev AS build
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements-build.txt
|
||||
|
||||
# install compiler cache to speed up compilation leveraging local or remote caching
|
||||
RUN apt-get update -y && apt-get install -y ccache
|
||||
|
||||
# copy input files
|
||||
COPY csrc csrc
|
||||
COPY setup.py setup.py
|
||||
COPY cmake cmake
|
||||
COPY CMakeLists.txt CMakeLists.txt
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm/__init__.py vllm/__init__.py
|
||||
|
||||
# cuda arch list used by torch
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# max jobs used by Ninja to build extensions
|
||||
ENV MAX_JOBS=$max_jobs
|
||||
RUN python3 setup.py build_ext --inplace
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
# make sure punica kernels are built (for LoRA)
|
||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
python3 setup.py build_ext --inplace
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
FROM dev as flash-attn-builder
|
||||
# max jobs used for build
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# flash attention version
|
||||
ARG flash_attn_version=v2.5.6
|
||||
ENV FLASH_ATTN_VERSION=${flash_attn_version}
|
||||
|
||||
WORKDIR /usr/src/flash-attention-v2
|
||||
|
||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
||||
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
|
||||
--no-build-isolation --no-deps --no-cache-dir
|
||||
|
||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
# image to run unit testing suite
|
||||
FROM dev AS test
|
||||
|
||||
# copy pytorch extensions separately to avoid having to rebuild
|
||||
# when python code changes
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY tests tests
|
||||
COPY vllm vllm
|
||||
WORKDIR /vllm-workspace
|
||||
# ADD is used to preserve directory structure
|
||||
ADD . /vllm-workspace/
|
||||
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
|
||||
# Install flash attention (from pre-built wheel)
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
# ignore build dependencies installation because we are using pre-complied extensions
|
||||
RUN rm pyproject.toml
|
||||
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "pytest", "tests"]
|
||||
|
||||
# use CUDA base as CUDA runtime dependencies are already installed via pip
|
||||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
|
||||
#################### RUNTIME BASE IMAGE ####################
|
||||
# We used base cuda image because pytorch installs its own cuda libraries.
|
||||
# However pynccl depends on cuda libraries so we had to switch to the runtime image
|
||||
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
|
||||
|
||||
# libnccl required for ray
|
||||
RUN apt-get update -y \
|
||||
@ -52,21 +115,24 @@ COPY requirements.txt requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements.txt
|
||||
|
||||
FROM vllm-base AS vllm
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY vllm vllm
|
||||
# Install flash attention (from pre-built wheel)
|
||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
||||
|
||||
EXPOSE 8000
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
|
||||
#################### RUNTIME BASE IMAGE ####################
|
||||
|
||||
|
||||
#################### OPENAI API SERVER ####################
|
||||
# openai api server alternative
|
||||
FROM vllm-base AS vllm-openai
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate fschat
|
||||
pip install accelerate hf_transfer modelscope
|
||||
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY vllm vllm
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
#################### OPENAI API SERVER ####################
|
||||
|
95
Dockerfile.rocm
Normal file
95
Dockerfile.rocm
Normal file
@ -0,0 +1,95 @@
|
||||
# default base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
RUN echo "Base image is $BASE_IMAGE"
|
||||
|
||||
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
||||
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
||||
|
||||
ARG FA_BRANCH="3d2b6f5"
|
||||
RUN echo "FA_BRANCH is $FA_BRANCH"
|
||||
|
||||
# whether to build flash-attention
|
||||
# if 0, will not build flash attention
|
||||
# this is useful for gfx target where flash-attention is not supported
|
||||
# In that case, we need to use the python reference attention implementation in vllm
|
||||
ARG BUILD_FA="1"
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
nvidia-cuda-toolkit \
|
||||
tmux \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Mount Point ###
|
||||
# When launching the container, mount the code directory to /app
|
||||
ARG APP_MOUNT=/app
|
||||
VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
# Install ROCm flash-attention
|
||||
RUN if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout ${FA_BRANCH} \
|
||||
&& git submodule update --init \
|
||||
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
|
||||
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
|
||||
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..; \
|
||||
fi
|
||||
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually removed it so that later steps of numpy upgrade can continue
|
||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
||||
|
||||
COPY ./ /app/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install xformers==0.0.23 --no-deps
|
||||
|
||||
RUN cd /app \
|
||||
&& cd vllm \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& if [ "$BUILD_FA" = "1" ]; then \
|
||||
bash patch_xformers.rocm.sh; fi \
|
||||
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
|
||||
|
||||
CMD ["/bin/bash"]
|
@ -1,4 +1,6 @@
|
||||
include LICENSE
|
||||
include requirements.txt
|
||||
include CMakeLists.txt
|
||||
|
||||
recursive-include cmake *
|
||||
recursive-include csrc *
|
||||
|
39
README.md
39
README.md
@ -10,13 +10,25 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)**
|
||||
|
||||
We are thrilled to announce our third vLLM Meetup!
|
||||
The vLLM team will share recent updates and roadmap.
|
||||
We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM.
|
||||
Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us!
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||
- [2024/01] Added ROCm 6.0 support to vLLM.
|
||||
- [2023/12] Added ROCm 5.7 support to vLLM.
|
||||
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||
@ -26,7 +38,7 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
vLLM is fast with:
|
||||
@ -34,6 +46,8 @@ vLLM is fast with:
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with **PagedAttention**
|
||||
- Continuous batching of incoming requests
|
||||
- Fast model execution with CUDA/HIP graph
|
||||
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
|
||||
- Optimized CUDA kernels
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
@ -43,25 +57,42 @@ vLLM is flexible and easy to use with:
|
||||
- Tensor parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs and AMD GPUs
|
||||
- (Experimental) Prefix caching support
|
||||
- (Experimental) Multi-lora support
|
||||
|
||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||
|
||||
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
|
||||
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
|
||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
|
||||
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
|
||||
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
|
||||
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
|
||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
|
||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
||||
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
|
||||
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
|
||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
|
||||
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
||||
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
||||
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
|
||||
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
|
||||
|
||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||
|
385
benchmarks/backend_request_func.py
Normal file
385
benchmarks/backend_request_func.py
Normal file
@ -0,0 +1,385 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
prompt: str
|
||||
api_url: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
best_of: int = 1
|
||||
use_beam_search: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0
|
||||
ttft: float = 0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
prompt_len: int = 0
|
||||
error: str = ""
|
||||
|
||||
|
||||
async def async_request_tgi(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||
}
|
||||
payload = {
|
||||
"inputs": request_func_input.prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.latency = most_recent_timestamp - st
|
||||
output.success = True
|
||||
output.generated_text = data["generated_text"]
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_trt_llm(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.latency = most_recent_timestamp - st
|
||||
output.generated_text = json.loads(data)["text_output"]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_deepspeed_mii(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
assert not request_func_input.use_beam_search
|
||||
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
||||
"top_p": 1.0,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
# NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
|
||||
# will use 0 as placeholder.
|
||||
# See https://github.com/microsoft/DeepSpeed-MII/pull/311
|
||||
output.ttft = 0
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=request_func_input.api_url,
|
||||
json=payload) as response:
|
||||
if response.status == 200:
|
||||
parsed_resp = await response.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"v1/completions"
|
||||
), "OpenAI Completions API URL must end with 'v1/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
|
||||
if data["choices"][0]["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# do not want to include as inter-token-latency
|
||||
elif data.get("usage", None) is None:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(
|
||||
"v1/chat/completions"
|
||||
), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": request_func_input.prompt,
|
||||
},
|
||||
],
|
||||
"temperature": 0.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
if "content" in data["choices"][0]["delta"]:
|
||||
# First token
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += data["choices"][0]["delta"][
|
||||
"content"]
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
|
||||
# introduced in Python 3.9
|
||||
def remove_prefix(text: str, prefix: str) -> str:
|
||||
if text.startswith(prefix):
|
||||
return text[len(prefix):]
|
||||
return text
|
||||
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
"tgi": async_request_tgi,
|
||||
"vllm": async_request_openai_completions,
|
||||
"lmdeploy": async_request_openai_completions,
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"openai-chat": async_request_openai_chat_completions,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -12,19 +14,21 @@ from vllm import LLM, SamplingParams
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
# Process all the requests in a single batch if possible.
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
llm = LLM(model=args.model,
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
)
|
||||
enforce_eager=args.enforce_eager,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
device=args.device,
|
||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
||||
download_dir=args.download_dir,
|
||||
block_size=args.block_size)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
n=args.n,
|
||||
@ -35,30 +39,50 @@ def main(args: argparse.Namespace):
|
||||
max_tokens=args.output_len,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
||||
|
||||
def run_to_completion(profile: bool = False):
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print(p.key_averages())
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
run_to_completion(profile=False)
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = args.profile_result_dir
|
||||
if not profile_dir:
|
||||
profile_dir = Path(
|
||||
"."
|
||||
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile=False))
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||
|
||||
|
||||
@ -70,7 +94,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'squeezellm', None],
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
@ -97,5 +121,51 @@ if __name__ == '__main__':
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=['auto', 'fp8_e5m2'],
|
||||
default='auto',
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
help='profile the generation process of a single batch')
|
||||
parser.add_argument(
|
||||
'--profile-result-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help=('path to save the pytorch profiler output. Can be visualized '
|
||||
'with ui.perfetto.dev or Tensorboard.'))
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
parser.add_argument('--block-size',
|
||||
type=int,
|
||||
default=16,
|
||||
help='block size of key/value cache')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
parser.add_argument(
|
||||
"--ray-workers-use-nsight",
|
||||
action='store_true',
|
||||
help="If specified, use nsight to profile ray workers",
|
||||
)
|
||||
parser.add_argument('--download-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
52
benchmarks/benchmark_prefix_caching.py
Normal file
52
benchmarks/benchmark_prefix_caching.py
Normal file
@ -0,0 +1,52 @@
|
||||
import argparse
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
||||
|
||||
|
||||
def test_prefix(llm=None, sampling_params=None, prompts=None):
|
||||
start_time = time.time()
|
||||
|
||||
llm.generate(prompts, sampling_params=sampling_params)
|
||||
|
||||
end_time = time.time()
|
||||
print(f"cost time {end_time - start_time}")
|
||||
|
||||
|
||||
def main(args):
|
||||
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
|
||||
tokenizer_mode='auto',
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=args.enable_prefix_caching)
|
||||
|
||||
num_prompts = 100
|
||||
prompts = [PROMPT] * num_prompts
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=100)
|
||||
|
||||
print("------warm up------")
|
||||
test_prefix(
|
||||
llm=llm,
|
||||
prompts=prompts[:1],
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
print("------start generating------")
|
||||
test_prefix(
|
||||
llm=llm,
|
||||
prompts=prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Benchmark the performance with or without automatic '
|
||||
'prefix caching.')
|
||||
parser.add_argument('--enable-prefix-caching',
|
||||
action='store_true',
|
||||
help='enable prefix caching')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -1,37 +1,60 @@
|
||||
"""Benchmark online serving throughput.
|
||||
|
||||
On the server side, run one of the following commands:
|
||||
(vLLM backend)
|
||||
python -m vllm.entrypoints.api_server \
|
||||
vLLM OpenAI API server
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model <your_model> --swap-space 16 \
|
||||
--disable-log-requests
|
||||
|
||||
(TGI backend)
|
||||
./launch_hf_server.sh <your_model>
|
||||
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
|
||||
|
||||
On the client side, run:
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend <backend> \
|
||||
--tokenizer <your_model> --dataset <target_dataset> \
|
||||
--request-rate <request_rate>
|
||||
--model <your_model> \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path <path to dataset> \
|
||||
--request-rate <request_rate> \ # By default <request_rate> is inf
|
||||
--num-prompts <num_prompts> # By default <num_prompts> is 1000
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# (prompt len, output len, latency)
|
||||
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
||||
|
||||
@dataclass
|
||||
class BenchmarkMetrics:
|
||||
completed: int
|
||||
total_input: int
|
||||
total_output: int
|
||||
request_throughput: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
|
||||
|
||||
def sample_requests(
|
||||
def sample_sharegpt_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -40,15 +63,15 @@ def sample_requests(
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [
|
||||
data for data in dataset
|
||||
if len(data["conversations"]) >= 2
|
||||
]
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# some of these will be filtered out, so sample more than we need
|
||||
sampled_indices = random.sample(range(len(dataset)),
|
||||
int(num_requests * 1.2))
|
||||
dataset = [dataset[i] for i in sampled_indices]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
@ -79,6 +102,73 @@ def sample_requests(
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_sonnet_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
prefix_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, str, int, int]]:
|
||||
assert input_len > prefix_len, "input_len must be greater than prefix_len."
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
poem_lines = f.readlines()
|
||||
|
||||
# Tokenize the poem lines.
|
||||
poem_token_ids = tokenizer(poem_lines).input_ids
|
||||
average_poem_len = sum(
|
||||
len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)
|
||||
|
||||
# Base prefix for all requests.
|
||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
||||
base_message = [{
|
||||
"role": "user",
|
||||
"content": base_prompt,
|
||||
}]
|
||||
base_prompt_formatted = tokenizer.apply_chat_template(
|
||||
base_message, add_generation_prompt=True, tokenize=False)
|
||||
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
||||
|
||||
assert (input_len > base_prompt_offset
|
||||
), f"Please set 'args.input-len' higher than {base_prompt_offset}."
|
||||
num_input_lines = round(
|
||||
(input_len - base_prompt_offset) / average_poem_len)
|
||||
|
||||
# First approximately `prefix_len` number of tokens in the
|
||||
# prompt are fixed poem lines.
|
||||
assert (
|
||||
prefix_len > base_prompt_offset
|
||||
), f"Please set 'args.prefix-len' higher than {base_prompt_offset}."
|
||||
|
||||
num_prefix_lines = round(
|
||||
(prefix_len - base_prompt_offset) / average_poem_len)
|
||||
prefix_lines = poem_lines[:num_prefix_lines]
|
||||
|
||||
# Sample the rest of lines per request.
|
||||
sampled_requests: List[Tuple[str, int, int]] = []
|
||||
for _ in range(num_requests):
|
||||
sampled_lines = "".join(
|
||||
prefix_lines +
|
||||
random.sample(poem_lines, num_input_lines - num_prefix_lines))
|
||||
|
||||
prompt = f"{base_prompt}{sampled_lines}"
|
||||
message = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
},
|
||||
]
|
||||
prompt_formatted = tokenizer.apply_chat_template(
|
||||
message, add_generation_prompt=True, tokenize=False)
|
||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
||||
sampled_requests.append(
|
||||
(prompt, prompt_formatted, prompt_len, output_len))
|
||||
|
||||
return sampled_requests
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
request_rate: float,
|
||||
@ -96,79 +186,149 @@ async def get_request(
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def send_request(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
prompt: str,
|
||||
prompt_len: int,
|
||||
output_len: int,
|
||||
best_of: int,
|
||||
use_beam_search: bool,
|
||||
) -> None:
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
headers = {"User-Agent": "Benchmark Client"}
|
||||
if backend == "vllm":
|
||||
pload = {
|
||||
"prompt": prompt,
|
||||
"n": 1,
|
||||
"best_of": best_of,
|
||||
"use_beam_search": use_beam_search,
|
||||
"temperature": 0.0 if use_beam_search else 1.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
"stream": False,
|
||||
}
|
||||
elif backend == "tgi":
|
||||
assert not use_beam_search
|
||||
params = {
|
||||
"best_of": best_of,
|
||||
"max_new_tokens": output_len,
|
||||
"do_sample": True,
|
||||
}
|
||||
pload = {
|
||||
"inputs": prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
def calculate_metrics(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens = []
|
||||
total_input = 0
|
||||
completed = 0
|
||||
tpots = []
|
||||
ttfts = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
output_len = len(tokenizer(outputs[i].generated_text).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
if output_len > 1:
|
||||
tpots.append(
|
||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||
ttfts.append(outputs[i].ttft)
|
||||
completed += 1
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
actual_output_lens.append(0)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
while True:
|
||||
async with session.post(api_url, headers=headers, json=pload) as response:
|
||||
chunks = []
|
||||
async for chunk, _ in response.content.iter_chunks():
|
||||
chunks.append(chunk)
|
||||
output = b"".join(chunks).decode("utf-8")
|
||||
output = json.loads(output)
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
total_output=sum(actual_output_lens),
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=sum(actual_output_lens) / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
||||
1000, # ttfts is empty if streaming is not supported by backend
|
||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(tpots) * 1000,
|
||||
median_tpot_ms=np.median(tpots) * 1000,
|
||||
p99_tpot_ms=np.percentile(tpots, 99) * 1000,
|
||||
)
|
||||
|
||||
# Re-send the request if it failed.
|
||||
if "error" not in output:
|
||||
break
|
||||
|
||||
request_end_time = time.perf_counter()
|
||||
request_latency = request_end_time - request_start_time
|
||||
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
||||
return metrics, actual_output_lens
|
||||
|
||||
|
||||
async def benchmark(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
best_of: int,
|
||||
use_beam_search: bool,
|
||||
request_rate: float,
|
||||
) -> None:
|
||||
tasks: List[asyncio.Task] = []
|
||||
disable_tqdm: bool,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS.get(backend)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks = []
|
||||
async for request in get_request(input_requests, request_rate):
|
||||
prompt, prompt_len, output_len = request
|
||||
task = asyncio.create_task(send_request(backend, api_url, prompt,
|
||||
prompt_len, output_len,
|
||||
best_of, use_beam_search))
|
||||
tasks.append(task)
|
||||
await asyncio.gather(*tasks)
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if not disable_tqdm:
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
metrics, actual_output_lens = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
||||
benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
||||
metrics.input_throughput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
metrics.output_throughput))
|
||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
||||
metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
||||
n=50,
|
||||
c='-'))
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
||||
metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("=" * 50)
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
"median_ttft_ms": metrics.median_ttft_ms,
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
"output_lens": actual_output_lens,
|
||||
"ttfts": [output.ttft for output in outputs],
|
||||
"itls": [output.itl for output in outputs],
|
||||
"generated_texts": [output.generated_text for output in outputs],
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -176,58 +336,252 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
api_url = f"http://{args.host}:{args.port}/generate"
|
||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
|
||||
args.use_beam_search, args.request_rate))
|
||||
benchmark_end_time = time.perf_counter()
|
||||
benchmark_time = benchmark_end_time - benchmark_start_time
|
||||
print(f"Total time: {benchmark_time:.2f} s")
|
||||
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
else:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
|
||||
# Compute the latency statistics.
|
||||
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
||||
print(f"Average latency: {avg_latency:.2f} s")
|
||||
avg_per_token_latency = np.mean([
|
||||
latency / (prompt_len + output_len)
|
||||
for prompt_len, output_len, latency in REQUEST_LATENCY
|
||||
])
|
||||
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
||||
avg_per_output_token_latency = np.mean([
|
||||
latency / output_len
|
||||
for _, output_len, latency in REQUEST_LATENCY
|
||||
])
|
||||
print("Average latency per output token: "
|
||||
f"{avg_per_output_token_latency:.2f} s")
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
|
||||
if args.dataset is not None:
|
||||
warnings.warn(
|
||||
"The '--dataset' argument will be deprecated in the next "
|
||||
"release. Please use '--dataset-name' and "
|
||||
"'--dataset-path' in the future runs.",
|
||||
stacklevel=2)
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sharegpt":
|
||||
input_requests = sample_sharegpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
# Do not format the prompt, pass to message directly
|
||||
if args.backend == "openai-chat":
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt, prompt_len, output_len)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len in input_requests]
|
||||
else:
|
||||
assert (
|
||||
tokenizer.chat_template or tokenizer.default_chat_template
|
||||
), "Tokenizer/model must have chat template for sonnet dataset."
|
||||
input_requests = sample_sonnet_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
prefix_len=args.prefix_len,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
input_requests = [(prompt_formatted, prompt_len, output_len)
|
||||
for prompt, prompt_formatted, prompt_len,
|
||||
output_len in input_requests]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
best_of=args.best_of,
|
||||
use_beam_search=args.use_beam_search,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
))
|
||||
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
result_json = {}
|
||||
|
||||
# Setup
|
||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
result_json["date"] = current_dt
|
||||
result_json["backend"] = backend
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["use_beam_search"] = args.use_beam_search
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Metadata
|
||||
if args.metadata:
|
||||
for item in args.metadata:
|
||||
if "=" in item:
|
||||
kvstring = item.split("=")
|
||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid metadata format. Please use KEY=VALUE format."
|
||||
)
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa
|
||||
if args.result_dir:
|
||||
file_name = os.path.join(args.result_dir, file_name)
|
||||
with open(file_name, "w") as outfile:
|
||||
json.dump(result_json, outfile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the online serving throughput.")
|
||||
parser.add_argument("--backend", type=str, default="vllm",
|
||||
choices=["vllm", "tgi"])
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="vllm",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Server or API base url if not using http host and port.",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
type=str,
|
||||
default="/v1/completions",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the ShareGPT dataset, will be deprecated in the "
|
||||
"next release.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "sonnet"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument("--dataset-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--tokenizer", type=str, required=True,
|
||||
help="Name or path of the tokenizer.")
|
||||
parser.add_argument("--best-of", type=int, default=1,
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default tokenizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.")
|
||||
"returns the best one.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--request-rate", type=float, default=float("inf"),
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-input-len",
|
||||
type=int,
|
||||
default=550,
|
||||
help=
|
||||
"Number of input tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-output-len",
|
||||
type=int,
|
||||
default=150,
|
||||
help=
|
||||
"Number of output tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sonnet-prefix-len",
|
||||
type=int,
|
||||
default=200,
|
||||
help=
|
||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
default=float("inf"),
|
||||
help="Number of requests per second. If this is inf, "
|
||||
"then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.")
|
||||
"the request arrival times.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument('--trust-remote-code', action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Trust remote code from huggingface",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-tqdm",
|
||||
action="store_true",
|
||||
help="Specify to disable tqdm progress bar.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-result",
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
metavar="KEY=VALUE",
|
||||
nargs="*",
|
||||
help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
|
||||
"for metadata of this run to be saved in the result JSON file "
|
||||
"for record keeping purposes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result-dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify directory to save benchmark json results."
|
||||
"If not specified, results are saved in the current directory.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -6,9 +6,9 @@ import time
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -17,8 +17,7 @@ def sample_requests(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None:
|
||||
if fixed_output_len < 4:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
@ -70,17 +69,29 @@ def run_vllm(
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
download_dir: Optional[str] = None,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
model=model,
|
||||
llm = LLM(model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
)
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir)
|
||||
|
||||
# Add the requests to the engine.
|
||||
for prompt, _, output_len in requests:
|
||||
@ -172,13 +183,15 @@ def run_mii(
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import pipeline
|
||||
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
|
||||
from mii import client, serve
|
||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [prompt for prompt, _, _ in requests]
|
||||
|
||||
start = time.perf_counter()
|
||||
llm(prompts, max_new_tokens=output_len)
|
||||
llm.generate(prompts, max_new_tokens=output_len)
|
||||
end = time.perf_counter()
|
||||
client = client(model)
|
||||
client.terminate_server()
|
||||
return end - start
|
||||
|
||||
|
||||
@ -202,7 +215,11 @@ def main(args: argparse.Namespace):
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype)
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len, args.enforce_eager,
|
||||
args.kv_cache_dtype, args.device,
|
||||
args.enable_prefix_caching,
|
||||
args.gpu_memory_utilization, args.download_dir)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -242,7 +259,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'squeezellm', None],
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n",
|
||||
@ -262,6 +279,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Maximum length of a sequence (including prompt and output). '
|
||||
'If None, will be derived from the model.')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
@ -271,6 +294,37 @@ if __name__ == "__main__":
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--gpu-memory-utilization',
|
||||
type=float,
|
||||
default=0.9,
|
||||
help='the fraction of GPU memory to be used for '
|
||||
'the model executor, which can range from 0 to 1.'
|
||||
'If unspecified, will use the default value of 0.9.')
|
||||
parser.add_argument("--enforce-eager",
|
||||
action="store_true",
|
||||
help="enforce eager execution")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
parser.add_argument(
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="enable automatic prefix caching for vLLM backend.")
|
||||
parser.add_argument('--download-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help='directory to download and load the weights, '
|
||||
'default to the default cache dir of huggingface')
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
182
benchmarks/kernels/benchmark_mixtral_moe.py
Normal file
182
benchmarks/kernels/benchmark_mixtral_moe.py
Normal file
@ -0,0 +1,182 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import (fused_moe,
|
||||
get_config_file_name)
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
|
||||
|
||||
def main():
|
||||
method = fused_moe
|
||||
for bs in [
|
||||
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
||||
2048, 3072, 4096
|
||||
]:
|
||||
run_grid(bs, method=method)
|
||||
|
||||
|
||||
def run_grid(bs, method):
|
||||
d_model = 4096
|
||||
num_total_experts = 8
|
||||
top_k = 2
|
||||
tp_size = 2
|
||||
model_intermediate_size = 14336
|
||||
num_layers = 32
|
||||
num_calls = 100
|
||||
|
||||
num_warmup_trials = 1
|
||||
num_trials = 1
|
||||
|
||||
configs = []
|
||||
if bs <= 16:
|
||||
BLOCK_SIZES_M = [16]
|
||||
elif bs <= 32:
|
||||
BLOCK_SIZES_M = [16, 32]
|
||||
elif bs <= 64:
|
||||
BLOCK_SIZES_M = [16, 32, 64]
|
||||
elif bs <= 128:
|
||||
BLOCK_SIZES_M = [16, 32, 64, 128]
|
||||
else:
|
||||
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
|
||||
|
||||
for block_size_n in [32, 64, 128, 256]:
|
||||
for block_size_m in BLOCK_SIZES_M:
|
||||
for block_size_k in [64, 128, 256]:
|
||||
for group_size_m in [1, 16, 32, 64]:
|
||||
for num_warps in [4, 8]:
|
||||
configs.append({
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": 4,
|
||||
})
|
||||
|
||||
best_config = None
|
||||
best_time_us = 1e20
|
||||
|
||||
for config in configs:
|
||||
print(f'{tp_size=} {bs=}')
|
||||
print(f'{config}')
|
||||
# warmup
|
||||
print('warming up')
|
||||
try:
|
||||
for _ in range(num_warmup_trials):
|
||||
run_timing(
|
||||
num_calls=num_calls,
|
||||
bs=bs,
|
||||
d_model=d_model,
|
||||
num_total_experts=num_total_experts,
|
||||
top_k=top_k,
|
||||
tp_size=tp_size,
|
||||
model_intermediate_size=model_intermediate_size,
|
||||
method=method,
|
||||
config=config,
|
||||
)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
continue
|
||||
|
||||
# trial
|
||||
print('benchmarking')
|
||||
for _ in range(num_trials):
|
||||
kernel_dur_ms = run_timing(
|
||||
num_calls=num_calls,
|
||||
bs=bs,
|
||||
d_model=d_model,
|
||||
num_total_experts=num_total_experts,
|
||||
top_k=top_k,
|
||||
tp_size=tp_size,
|
||||
model_intermediate_size=model_intermediate_size,
|
||||
method=method,
|
||||
config=config,
|
||||
)
|
||||
|
||||
kernel_dur_us = 1000 * kernel_dur_ms
|
||||
model_dur_ms = kernel_dur_ms * num_layers
|
||||
|
||||
if kernel_dur_us < best_time_us:
|
||||
best_config = config
|
||||
best_time_us = kernel_dur_us
|
||||
|
||||
print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
|
||||
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
|
||||
f'{d_model=} {model_intermediate_size=} {num_layers=}')
|
||||
|
||||
print("best_time_us", best_time_us)
|
||||
print("best_config", best_config)
|
||||
|
||||
# holds Dict[str, Dict[str, int]]
|
||||
filename = get_config_file_name(num_total_experts,
|
||||
model_intermediate_size // tp_size)
|
||||
print(f"writing config to file {filename}")
|
||||
existing_content = {}
|
||||
if os.path.exists(filename):
|
||||
with open(filename, "r") as f:
|
||||
existing_content = json.load(f)
|
||||
existing_content[str(bs)] = best_config
|
||||
with open(filename, "w") as f:
|
||||
json.dump(existing_content, f, indent=4)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
|
||||
top_k: int, tp_size: int, model_intermediate_size: int, method,
|
||||
config) -> float:
|
||||
shard_intermediate_size = model_intermediate_size // tp_size
|
||||
|
||||
hidden_states = torch.rand(
|
||||
(bs, d_model),
|
||||
device="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
ws = torch.rand(
|
||||
(num_total_experts, 2 * shard_intermediate_size, d_model),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
w2s = torch.rand(
|
||||
(num_total_experts, d_model, shard_intermediate_size),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
gating_output = F.softmax(torch.rand(
|
||||
(num_calls, bs, num_total_experts),
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
dim=-1)
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
for i in range(num_calls):
|
||||
hidden_states = method(
|
||||
hidden_states=hidden_states,
|
||||
w1=ws,
|
||||
w2=w2s,
|
||||
gating_output=gating_output[i],
|
||||
topk=2,
|
||||
renormalize=True,
|
||||
inplace=True,
|
||||
override_config=config,
|
||||
)
|
||||
end_event.record()
|
||||
end_event.synchronize()
|
||||
|
||||
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
||||
return dur_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
@ -1,10 +1,12 @@
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import attention_ops
|
||||
from vllm._C import ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
PARTITION_SIZE = 512
|
||||
@ -23,9 +25,12 @@ def main(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
do_profile: bool,
|
||||
device: str = "cuda",
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
@ -33,23 +38,19 @@ def main(
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
device=device)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
device=device)
|
||||
|
||||
context_lens = [context_len for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
@ -60,18 +61,18 @@ def main(
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
|
||||
|
||||
# Create the KV cache.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
||||
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
|
||||
key_cache.uniform_(-scale, scale)
|
||||
value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
|
||||
value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
value_cache.uniform_(-scale, scale)
|
||||
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
@ -90,7 +91,7 @@ def main(
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
def run_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
torch.cuda.synchronize()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
@ -98,21 +99,22 @@ def main(
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
attention_ops.paged_attention_v1(
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
elif version == "v2":
|
||||
attention_ops.paged_attention_v2(
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
@ -120,13 +122,14 @@ def main(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
@ -139,6 +142,7 @@ def main(
|
||||
|
||||
# Warmup.
|
||||
print("Warming up...")
|
||||
run_benchmark = run_cuda_benchmark
|
||||
run_benchmark(num_iters=3, profile=False)
|
||||
|
||||
# Benchmark.
|
||||
@ -172,16 +176,19 @@ if __name__ == '__main__':
|
||||
default="half")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.num_query_heads % args.num_kv_heads != 0:
|
||||
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
||||
dtype_to_torch_dtype = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
}
|
||||
main(
|
||||
version=args.version,
|
||||
num_seqs=args.batch_size,
|
||||
@ -191,7 +198,8 @@ if __name__ == '__main__':
|
||||
head_size=args.head_size,
|
||||
block_size=args.block_size,
|
||||
use_alibi=args.use_alibi,
|
||||
dtype=dtype_to_torch_dtype[args.dtype],
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
)
|
||||
|
121
benchmarks/kernels/benchmark_rope.py
Normal file
121
benchmarks/kernels/benchmark_rope.py
Normal file
@ -0,0 +1,121 @@
|
||||
import argparse
|
||||
from itertools import accumulate
|
||||
from typing import Optional
|
||||
|
||||
import nvtx
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
|
||||
def benchmark_rope_kernels_multi_lora(
|
||||
is_neox_style: bool,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
# silulating serving 4 LoRAs
|
||||
scaling_factors = [1, 2, 4, 8]
|
||||
# batched RoPE can take multiple scaling factors
|
||||
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, {
|
||||
"type": "linear",
|
||||
"factor": tuple(scaling_factors)
|
||||
})
|
||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
||||
# instances to simulate the same behavior
|
||||
non_batched_ropes = []
|
||||
for scaling_factor in scaling_factors:
|
||||
non_batched_ropes.append(
|
||||
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
{
|
||||
"type": "linear",
|
||||
"factor": (scaling_factor, )
|
||||
}))
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
|
||||
# create query offsets for batched RoPE, we concat multiple kv cache
|
||||
# together and each query needs to find the right kv cache of its type
|
||||
offset_map = torch.tensor(
|
||||
list(
|
||||
accumulate([0] + [
|
||||
max_position * scaling_factor * 2
|
||||
for scaling_factor in scaling_factors[:-1]
|
||||
])))
|
||||
query_types = torch.randint(0,
|
||||
len(scaling_factors), (batch_size, seq_len),
|
||||
device=device)
|
||||
# map query types to offsets
|
||||
query_offsets = offset_map[query_types]
|
||||
# the kernel takes flattened offsets
|
||||
flatten_offsets = query_offsets.flatten()
|
||||
|
||||
# batched queries of the same type together for non-batched RoPE
|
||||
queries = [query[query_types == i] for i in range(len(scaling_factors))]
|
||||
keys = [key[query_types == i] for i in range(len(scaling_factors))]
|
||||
packed_qkr = zip(queries, keys, non_batched_ropes)
|
||||
# synchronize before start timing
|
||||
torch.cuda.synchronize()
|
||||
with nvtx.annotate("non-batched", color="yellow"):
|
||||
for q, k, r in packed_qkr:
|
||||
r.forward(positions, q, k)
|
||||
torch.cuda.synchronize()
|
||||
with nvtx.annotate("batched", color="green"):
|
||||
batched_rope.forward(positions, query, key, flatten_offsets)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the rotary embedding kernels.")
|
||||
parser.add_argument("--is-neox-style", type=bool, default=True)
|
||||
parser.add_argument("--batch-size", type=int, default=16)
|
||||
parser.add_argument("--seq-len", type=int, default=512)
|
||||
parser.add_argument("--num-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 128, 256],
|
||||
default=128)
|
||||
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["bfloat16", "float"],
|
||||
default="float")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--device",
|
||||
type=str,
|
||||
choices=["cuda:0", "cuda:1"],
|
||||
default="cuda:0")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
benchmark_rope_kernels_multi_lora(
|
||||
is_neox_style=args.is_neox_style,
|
||||
batch_size=args.batch_size,
|
||||
seq_len=args.seq_len,
|
||||
num_heads=args.num_heads,
|
||||
head_size=args.head_size,
|
||||
rotary_dim=args.rotary_dim,
|
||||
dtype=getattr(torch, args.dtype),
|
||||
seed=args.seed,
|
||||
device=args.device,
|
||||
)
|
@ -6,7 +6,7 @@ TOKENS=$2
|
||||
|
||||
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
||||
-v $PWD/data:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:0.8 \
|
||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
||||
--model-id $MODEL \
|
||||
--sharded false \
|
||||
--max-input-length 1024 \
|
||||
|
518
benchmarks/sonnet.txt
Normal file
518
benchmarks/sonnet.txt
Normal file
@ -0,0 +1,518 @@
|
||||
FROM fairest creatures we desire increase,
|
||||
That thereby beauty's rose might never die,
|
||||
But as the riper should by time decease,
|
||||
His tender heir might bear his memory:
|
||||
But thou, contracted to thine own bright eyes,
|
||||
Feed'st thy light'st flame with self-substantial fuel,
|
||||
Making a famine where abundance lies,
|
||||
Thyself thy foe, to thy sweet self too cruel.
|
||||
Thou that art now the world's fresh ornament
|
||||
And only herald to the gaudy spring,
|
||||
Within thine own bud buriest thy content
|
||||
And, tender churl, makest waste in niggarding.
|
||||
Pity the world, or else this glutton be,
|
||||
To eat the world's due, by the grave and thee.
|
||||
When forty winters shall beseige thy brow,
|
||||
And dig deep trenches in thy beauty's field,
|
||||
Thy youth's proud livery, so gazed on now,
|
||||
Will be a tatter'd weed, of small worth held:
|
||||
Then being ask'd where all thy beauty lies,
|
||||
Where all the treasure of thy lusty days,
|
||||
To say, within thine own deep-sunken eyes,
|
||||
Were an all-eating shame and thriftless praise.
|
||||
How much more praise deserved thy beauty's use,
|
||||
If thou couldst answer 'This fair child of mine
|
||||
Shall sum my count and make my old excuse,'
|
||||
Proving his beauty by succession thine!
|
||||
This were to be new made when thou art old,
|
||||
And see thy blood warm when thou feel'st it cold.
|
||||
Look in thy glass, and tell the face thou viewest
|
||||
Now is the time that face should form another;
|
||||
Whose fresh repair if now thou not renewest,
|
||||
Thou dost beguile the world, unbless some mother.
|
||||
For where is she so fair whose unear'd womb
|
||||
Disdains the tillage of thy husbandry?
|
||||
Or who is he so fond will be the tomb
|
||||
Of his self-love, to stop posterity?
|
||||
Thou art thy mother's glass, and she in thee
|
||||
Calls back the lovely April of her prime:
|
||||
So thou through windows of thine age shall see
|
||||
Despite of wrinkles this thy golden time.
|
||||
But if thou live, remember'd not to be,
|
||||
Die single, and thine image dies with thee.
|
||||
Unthrifty loveliness, why dost thou spend
|
||||
Upon thyself thy beauty's legacy?
|
||||
Nature's bequest gives nothing but doth lend,
|
||||
And being frank she lends to those are free.
|
||||
Then, beauteous niggard, why dost thou abuse
|
||||
The bounteous largess given thee to give?
|
||||
Profitless usurer, why dost thou use
|
||||
So great a sum of sums, yet canst not live?
|
||||
For having traffic with thyself alone,
|
||||
Thou of thyself thy sweet self dost deceive.
|
||||
Then how, when nature calls thee to be gone,
|
||||
What acceptable audit canst thou leave?
|
||||
Thy unused beauty must be tomb'd with thee,
|
||||
Which, used, lives th' executor to be.
|
||||
Those hours, that with gentle work did frame
|
||||
The lovely gaze where every eye doth dwell,
|
||||
Will play the tyrants to the very same
|
||||
And that unfair which fairly doth excel:
|
||||
For never-resting time leads summer on
|
||||
To hideous winter and confounds him there;
|
||||
Sap cheque'd with frost and lusty leaves quite gone,
|
||||
Beauty o'ersnow'd and bareness every where:
|
||||
Then, were not summer's distillation left,
|
||||
A liquid prisoner pent in walls of glass,
|
||||
Beauty's effect with beauty were bereft,
|
||||
Nor it nor no remembrance what it was:
|
||||
But flowers distill'd though they with winter meet,
|
||||
Leese but their show; their substance still lives sweet.
|
||||
Then let not winter's ragged hand deface
|
||||
In thee thy summer, ere thou be distill'd:
|
||||
Make sweet some vial; treasure thou some place
|
||||
With beauty's treasure, ere it be self-kill'd.
|
||||
That use is not forbidden usury,
|
||||
Which happies those that pay the willing loan;
|
||||
That's for thyself to breed another thee,
|
||||
Or ten times happier, be it ten for one;
|
||||
Ten times thyself were happier than thou art,
|
||||
If ten of thine ten times refigured thee:
|
||||
Then what could death do, if thou shouldst depart,
|
||||
Leaving thee living in posterity?
|
||||
Be not self-will'd, for thou art much too fair
|
||||
To be death's conquest and make worms thine heir.
|
||||
Lo! in the orient when the gracious light
|
||||
Lifts up his burning head, each under eye
|
||||
Doth homage to his new-appearing sight,
|
||||
Serving with looks his sacred majesty;
|
||||
And having climb'd the steep-up heavenly hill,
|
||||
Resembling strong youth in his middle age,
|
||||
yet mortal looks adore his beauty still,
|
||||
Attending on his golden pilgrimage;
|
||||
But when from highmost pitch, with weary car,
|
||||
Like feeble age, he reeleth from the day,
|
||||
The eyes, 'fore duteous, now converted are
|
||||
From his low tract and look another way:
|
||||
So thou, thyself out-going in thy noon,
|
||||
Unlook'd on diest, unless thou get a son.
|
||||
Music to hear, why hear'st thou music sadly?
|
||||
Sweets with sweets war not, joy delights in joy.
|
||||
Why lovest thou that which thou receivest not gladly,
|
||||
Or else receivest with pleasure thine annoy?
|
||||
If the true concord of well-tuned sounds,
|
||||
By unions married, do offend thine ear,
|
||||
They do but sweetly chide thee, who confounds
|
||||
In singleness the parts that thou shouldst bear.
|
||||
Mark how one string, sweet husband to another,
|
||||
Strikes each in each by mutual ordering,
|
||||
Resembling sire and child and happy mother
|
||||
Who all in one, one pleasing note do sing:
|
||||
Whose speechless song, being many, seeming one,
|
||||
Sings this to thee: 'thou single wilt prove none.'
|
||||
Is it for fear to wet a widow's eye
|
||||
That thou consumest thyself in single life?
|
||||
Ah! if thou issueless shalt hap to die.
|
||||
The world will wail thee, like a makeless wife;
|
||||
The world will be thy widow and still weep
|
||||
That thou no form of thee hast left behind,
|
||||
When every private widow well may keep
|
||||
By children's eyes her husband's shape in mind.
|
||||
Look, what an unthrift in the world doth spend
|
||||
Shifts but his place, for still the world enjoys it;
|
||||
But beauty's waste hath in the world an end,
|
||||
And kept unused, the user so destroys it.
|
||||
No love toward others in that bosom sits
|
||||
That on himself such murderous shame commits.
|
||||
For shame! deny that thou bear'st love to any,
|
||||
Who for thyself art so unprovident.
|
||||
Grant, if thou wilt, thou art beloved of many,
|
||||
But that thou none lovest is most evident;
|
||||
For thou art so possess'd with murderous hate
|
||||
That 'gainst thyself thou stick'st not to conspire.
|
||||
Seeking that beauteous roof to ruinate
|
||||
Which to repair should be thy chief desire.
|
||||
O, change thy thought, that I may change my mind!
|
||||
Shall hate be fairer lodged than gentle love?
|
||||
Be, as thy presence is, gracious and kind,
|
||||
Or to thyself at least kind-hearted prove:
|
||||
Make thee another self, for love of me,
|
||||
That beauty still may live in thine or thee.
|
||||
As fast as thou shalt wane, so fast thou growest
|
||||
In one of thine, from that which thou departest;
|
||||
And that fresh blood which youngly thou bestowest
|
||||
Thou mayst call thine when thou from youth convertest.
|
||||
Herein lives wisdom, beauty and increase:
|
||||
Without this, folly, age and cold decay:
|
||||
If all were minded so, the times should cease
|
||||
And threescore year would make the world away.
|
||||
Let those whom Nature hath not made for store,
|
||||
Harsh featureless and rude, barrenly perish:
|
||||
Look, whom she best endow'd she gave the more;
|
||||
Which bounteous gift thou shouldst in bounty cherish:
|
||||
She carved thee for her seal, and meant thereby
|
||||
Thou shouldst print more, not let that copy die.
|
||||
When I do count the clock that tells the time,
|
||||
And see the brave day sunk in hideous night;
|
||||
When I behold the violet past prime,
|
||||
And sable curls all silver'd o'er with white;
|
||||
When lofty trees I see barren of leaves
|
||||
Which erst from heat did canopy the herd,
|
||||
And summer's green all girded up in sheaves
|
||||
Borne on the bier with white and bristly beard,
|
||||
Then of thy beauty do I question make,
|
||||
That thou among the wastes of time must go,
|
||||
Since sweets and beauties do themselves forsake
|
||||
And die as fast as they see others grow;
|
||||
And nothing 'gainst Time's scythe can make defence
|
||||
Save breed, to brave him when he takes thee hence.
|
||||
O, that you were yourself! but, love, you are
|
||||
No longer yours than you yourself here live:
|
||||
Against this coming end you should prepare,
|
||||
And your sweet semblance to some other give.
|
||||
So should that beauty which you hold in lease
|
||||
Find no determination: then you were
|
||||
Yourself again after yourself's decease,
|
||||
When your sweet issue your sweet form should bear.
|
||||
Who lets so fair a house fall to decay,
|
||||
Which husbandry in honour might uphold
|
||||
Against the stormy gusts of winter's day
|
||||
And barren rage of death's eternal cold?
|
||||
O, none but unthrifts! Dear my love, you know
|
||||
You had a father: let your son say so.
|
||||
Not from the stars do I my judgment pluck;
|
||||
And yet methinks I have astronomy,
|
||||
But not to tell of good or evil luck,
|
||||
Of plagues, of dearths, or seasons' quality;
|
||||
Nor can I fortune to brief minutes tell,
|
||||
Pointing to each his thunder, rain and wind,
|
||||
Or say with princes if it shall go well,
|
||||
By oft predict that I in heaven find:
|
||||
But from thine eyes my knowledge I derive,
|
||||
And, constant stars, in them I read such art
|
||||
As truth and beauty shall together thrive,
|
||||
If from thyself to store thou wouldst convert;
|
||||
Or else of thee this I prognosticate:
|
||||
Thy end is truth's and beauty's doom and date.
|
||||
When I consider every thing that grows
|
||||
Holds in perfection but a little moment,
|
||||
That this huge stage presenteth nought but shows
|
||||
Whereon the stars in secret influence comment;
|
||||
When I perceive that men as plants increase,
|
||||
Cheered and cheque'd even by the self-same sky,
|
||||
Vaunt in their youthful sap, at height decrease,
|
||||
And wear their brave state out of memory;
|
||||
Then the conceit of this inconstant stay
|
||||
Sets you most rich in youth before my sight,
|
||||
Where wasteful Time debateth with Decay,
|
||||
To change your day of youth to sullied night;
|
||||
And all in war with Time for love of you,
|
||||
As he takes from you, I engraft you new.
|
||||
But wherefore do not you a mightier way
|
||||
Make war upon this bloody tyrant, Time?
|
||||
And fortify yourself in your decay
|
||||
With means more blessed than my barren rhyme?
|
||||
Now stand you on the top of happy hours,
|
||||
And many maiden gardens yet unset
|
||||
With virtuous wish would bear your living flowers,
|
||||
Much liker than your painted counterfeit:
|
||||
So should the lines of life that life repair,
|
||||
Which this, Time's pencil, or my pupil pen,
|
||||
Neither in inward worth nor outward fair,
|
||||
Can make you live yourself in eyes of men.
|
||||
To give away yourself keeps yourself still,
|
||||
And you must live, drawn by your own sweet skill.
|
||||
Who will believe my verse in time to come,
|
||||
If it were fill'd with your most high deserts?
|
||||
Though yet, heaven knows, it is but as a tomb
|
||||
Which hides your life and shows not half your parts.
|
||||
If I could write the beauty of your eyes
|
||||
And in fresh numbers number all your graces,
|
||||
The age to come would say 'This poet lies:
|
||||
Such heavenly touches ne'er touch'd earthly faces.'
|
||||
So should my papers yellow'd with their age
|
||||
Be scorn'd like old men of less truth than tongue,
|
||||
And your true rights be term'd a poet's rage
|
||||
And stretched metre of an antique song:
|
||||
But were some child of yours alive that time,
|
||||
You should live twice; in it and in my rhyme.
|
||||
Shall I compare thee to a summer's day?
|
||||
Thou art more lovely and more temperate:
|
||||
Rough winds do shake the darling buds of May,
|
||||
And summer's lease hath all too short a date:
|
||||
Sometime too hot the eye of heaven shines,
|
||||
And often is his gold complexion dimm'd;
|
||||
And every fair from fair sometime declines,
|
||||
By chance or nature's changing course untrimm'd;
|
||||
But thy eternal summer shall not fade
|
||||
Nor lose possession of that fair thou owest;
|
||||
Nor shall Death brag thou wander'st in his shade,
|
||||
When in eternal lines to time thou growest:
|
||||
So long as men can breathe or eyes can see,
|
||||
So long lives this and this gives life to thee.
|
||||
Devouring Time, blunt thou the lion's paws,
|
||||
And make the earth devour her own sweet brood;
|
||||
Pluck the keen teeth from the fierce tiger's jaws,
|
||||
And burn the long-lived phoenix in her blood;
|
||||
Make glad and sorry seasons as thou fleets,
|
||||
And do whate'er thou wilt, swift-footed Time,
|
||||
To the wide world and all her fading sweets;
|
||||
But I forbid thee one most heinous crime:
|
||||
O, carve not with thy hours my love's fair brow,
|
||||
Nor draw no lines there with thine antique pen;
|
||||
Him in thy course untainted do allow
|
||||
For beauty's pattern to succeeding men.
|
||||
Yet, do thy worst, old Time: despite thy wrong,
|
||||
My love shall in my verse ever live young.
|
||||
A woman's face with Nature's own hand painted
|
||||
Hast thou, the master-mistress of my passion;
|
||||
A woman's gentle heart, but not acquainted
|
||||
With shifting change, as is false women's fashion;
|
||||
An eye more bright than theirs, less false in rolling,
|
||||
Gilding the object whereupon it gazeth;
|
||||
A man in hue, all 'hues' in his controlling,
|
||||
Much steals men's eyes and women's souls amazeth.
|
||||
And for a woman wert thou first created;
|
||||
Till Nature, as she wrought thee, fell a-doting,
|
||||
And by addition me of thee defeated,
|
||||
By adding one thing to my purpose nothing.
|
||||
But since she prick'd thee out for women's pleasure,
|
||||
Mine be thy love and thy love's use their treasure.
|
||||
So is it not with me as with that Muse
|
||||
Stirr'd by a painted beauty to his verse,
|
||||
Who heaven itself for ornament doth use
|
||||
And every fair with his fair doth rehearse
|
||||
Making a couplement of proud compare,
|
||||
With sun and moon, with earth and sea's rich gems,
|
||||
With April's first-born flowers, and all things rare
|
||||
That heaven's air in this huge rondure hems.
|
||||
O' let me, true in love, but truly write,
|
||||
And then believe me, my love is as fair
|
||||
As any mother's child, though not so bright
|
||||
As those gold candles fix'd in heaven's air:
|
||||
Let them say more than like of hearsay well;
|
||||
I will not praise that purpose not to sell.
|
||||
My glass shall not persuade me I am old,
|
||||
So long as youth and thou are of one date;
|
||||
But when in thee time's furrows I behold,
|
||||
Then look I death my days should expiate.
|
||||
For all that beauty that doth cover thee
|
||||
Is but the seemly raiment of my heart,
|
||||
Which in thy breast doth live, as thine in me:
|
||||
How can I then be elder than thou art?
|
||||
O, therefore, love, be of thyself so wary
|
||||
As I, not for myself, but for thee will;
|
||||
Bearing thy heart, which I will keep so chary
|
||||
As tender nurse her babe from faring ill.
|
||||
Presume not on thy heart when mine is slain;
|
||||
Thou gavest me thine, not to give back again.
|
||||
As an unperfect actor on the stage
|
||||
Who with his fear is put besides his part,
|
||||
Or some fierce thing replete with too much rage,
|
||||
Whose strength's abundance weakens his own heart.
|
||||
So I, for fear of trust, forget to say
|
||||
The perfect ceremony of love's rite,
|
||||
And in mine own love's strength seem to decay,
|
||||
O'ercharged with burden of mine own love's might.
|
||||
O, let my books be then the eloquence
|
||||
And dumb presagers of my speaking breast,
|
||||
Who plead for love and look for recompense
|
||||
More than that tongue that more hath more express'd.
|
||||
O, learn to read what silent love hath writ:
|
||||
To hear with eyes belongs to love's fine wit.
|
||||
Mine eye hath play'd the painter and hath stell'd
|
||||
Thy beauty's form in table of my heart;
|
||||
My body is the frame wherein 'tis held,
|
||||
And perspective it is the painter's art.
|
||||
For through the painter must you see his skill,
|
||||
To find where your true image pictured lies;
|
||||
Which in my bosom's shop is hanging still,
|
||||
That hath his windows glazed with thine eyes.
|
||||
Now see what good turns eyes for eyes have done:
|
||||
Mine eyes have drawn thy shape, and thine for me
|
||||
Are windows to my breast, where-through the sun
|
||||
Delights to peep, to gaze therein on thee;
|
||||
Yet eyes this cunning want to grace their art;
|
||||
They draw but what they see, know not the heart.
|
||||
Let those who are in favour with their stars
|
||||
Of public honour and proud titles boast,
|
||||
Whilst I, whom fortune of such triumph bars,
|
||||
Unlook'd for joy in that I honour most.
|
||||
Great princes' favourites their fair leaves spread
|
||||
But as the marigold at the sun's eye,
|
||||
And in themselves their pride lies buried,
|
||||
For at a frown they in their glory die.
|
||||
The painful warrior famoused for fight,
|
||||
After a thousand victories once foil'd,
|
||||
Is from the book of honour razed quite,
|
||||
And all the rest forgot for which he toil'd:
|
||||
Then happy I, that love and am beloved
|
||||
Where I may not remove nor be removed.
|
||||
Lord of my love, to whom in vassalage
|
||||
Thy merit hath my duty strongly knit,
|
||||
To thee I send this written embassage,
|
||||
To witness duty, not to show my wit:
|
||||
Duty so great, which wit so poor as mine
|
||||
May make seem bare, in wanting words to show it,
|
||||
But that I hope some good conceit of thine
|
||||
In thy soul's thought, all naked, will bestow it;
|
||||
Till whatsoever star that guides my moving
|
||||
Points on me graciously with fair aspect
|
||||
And puts apparel on my tatter'd loving,
|
||||
To show me worthy of thy sweet respect:
|
||||
Then may I dare to boast how I do love thee;
|
||||
Till then not show my head where thou mayst prove me.
|
||||
Weary with toil, I haste me to my bed,
|
||||
The dear repose for limbs with travel tired;
|
||||
But then begins a journey in my head,
|
||||
To work my mind, when body's work's expired:
|
||||
For then my thoughts, from far where I abide,
|
||||
Intend a zealous pilgrimage to thee,
|
||||
And keep my drooping eyelids open wide,
|
||||
Looking on darkness which the blind do see
|
||||
Save that my soul's imaginary sight
|
||||
Presents thy shadow to my sightless view,
|
||||
Which, like a jewel hung in ghastly night,
|
||||
Makes black night beauteous and her old face new.
|
||||
Lo! thus, by day my limbs, by night my mind,
|
||||
For thee and for myself no quiet find.
|
||||
How can I then return in happy plight,
|
||||
That am debarr'd the benefit of rest?
|
||||
When day's oppression is not eased by night,
|
||||
But day by night, and night by day, oppress'd?
|
||||
And each, though enemies to either's reign,
|
||||
Do in consent shake hands to torture me;
|
||||
The one by toil, the other to complain
|
||||
How far I toil, still farther off from thee.
|
||||
I tell the day, to please them thou art bright
|
||||
And dost him grace when clouds do blot the heaven:
|
||||
So flatter I the swart-complexion'd night,
|
||||
When sparkling stars twire not thou gild'st the even.
|
||||
But day doth daily draw my sorrows longer
|
||||
And night doth nightly make grief's strength seem stronger.
|
||||
When, in disgrace with fortune and men's eyes,
|
||||
I all alone beweep my outcast state
|
||||
And trouble deal heaven with my bootless cries
|
||||
And look upon myself and curse my fate,
|
||||
Wishing me like to one more rich in hope,
|
||||
Featured like him, like him with friends possess'd,
|
||||
Desiring this man's art and that man's scope,
|
||||
With what I most enjoy contented least;
|
||||
Yet in these thoughts myself almost despising,
|
||||
Haply I think on thee, and then my state,
|
||||
Like to the lark at break of day arising
|
||||
From sullen earth, sings hymns at heaven's gate;
|
||||
For thy sweet love remember'd such wealth brings
|
||||
That then I scorn to change my state with kings.
|
||||
When to the sessions of sweet silent thought
|
||||
I summon up remembrance of things past,
|
||||
I sigh the lack of many a thing I sought,
|
||||
And with old woes new wail my dear time's waste:
|
||||
Then can I drown an eye, unused to flow,
|
||||
For precious friends hid in death's dateless night,
|
||||
And weep afresh love's long since cancell'd woe,
|
||||
And moan the expense of many a vanish'd sight:
|
||||
Then can I grieve at grievances foregone,
|
||||
And heavily from woe to woe tell o'er
|
||||
The sad account of fore-bemoaned moan,
|
||||
Which I new pay as if not paid before.
|
||||
But if the while I think on thee, dear friend,
|
||||
All losses are restored and sorrows end.
|
||||
Thy bosom is endeared with all hearts,
|
||||
Which I by lacking have supposed dead,
|
||||
And there reigns love and all love's loving parts,
|
||||
And all those friends which I thought buried.
|
||||
How many a holy and obsequious tear
|
||||
Hath dear religious love stol'n from mine eye
|
||||
As interest of the dead, which now appear
|
||||
But things removed that hidden in thee lie!
|
||||
Thou art the grave where buried love doth live,
|
||||
Hung with the trophies of my lovers gone,
|
||||
Who all their parts of me to thee did give;
|
||||
That due of many now is thine alone:
|
||||
Their images I loved I view in thee,
|
||||
And thou, all they, hast all the all of me.
|
||||
If thou survive my well-contented day,
|
||||
When that churl Death my bones with dust shall cover,
|
||||
And shalt by fortune once more re-survey
|
||||
These poor rude lines of thy deceased lover,
|
||||
Compare them with the bettering of the time,
|
||||
And though they be outstripp'd by every pen,
|
||||
Reserve them for my love, not for their rhyme,
|
||||
Exceeded by the height of happier men.
|
||||
O, then vouchsafe me but this loving thought:
|
||||
'Had my friend's Muse grown with this growing age,
|
||||
A dearer birth than this his love had brought,
|
||||
To march in ranks of better equipage:
|
||||
But since he died and poets better prove,
|
||||
Theirs for their style I'll read, his for his love.'
|
||||
Full many a glorious morning have I seen
|
||||
Flatter the mountain-tops with sovereign eye,
|
||||
Kissing with golden face the meadows green,
|
||||
Gilding pale streams with heavenly alchemy;
|
||||
Anon permit the basest clouds to ride
|
||||
With ugly rack on his celestial face,
|
||||
And from the forlorn world his visage hide,
|
||||
Stealing unseen to west with this disgrace:
|
||||
Even so my sun one early morn did shine
|
||||
With all triumphant splendor on my brow;
|
||||
But out, alack! he was but one hour mine;
|
||||
The region cloud hath mask'd him from me now.
|
||||
Yet him for this my love no whit disdaineth;
|
||||
Suns of the world may stain when heaven's sun staineth.
|
||||
Why didst thou promise such a beauteous day,
|
||||
And make me travel forth without my cloak,
|
||||
To let base clouds o'ertake me in my way,
|
||||
Hiding thy bravery in their rotten smoke?
|
||||
'Tis not enough that through the cloud thou break,
|
||||
To dry the rain on my storm-beaten face,
|
||||
For no man well of such a salve can speak
|
||||
That heals the wound and cures not the disgrace:
|
||||
Nor can thy shame give physic to my grief;
|
||||
Though thou repent, yet I have still the loss:
|
||||
The offender's sorrow lends but weak relief
|
||||
To him that bears the strong offence's cross.
|
||||
Ah! but those tears are pearl which thy love sheds,
|
||||
And they are rich and ransom all ill deeds.
|
||||
No more be grieved at that which thou hast done:
|
||||
Roses have thorns, and silver fountains mud;
|
||||
Clouds and eclipses stain both moon and sun,
|
||||
And loathsome canker lives in sweetest bud.
|
||||
All men make faults, and even I in this,
|
||||
Authorizing thy trespass with compare,
|
||||
Myself corrupting, salving thy amiss,
|
||||
Excusing thy sins more than thy sins are;
|
||||
For to thy sensual fault I bring in sense--
|
||||
Thy adverse party is thy advocate--
|
||||
And 'gainst myself a lawful plea commence:
|
||||
Such civil war is in my love and hate
|
||||
That I an accessary needs must be
|
||||
To that sweet thief which sourly robs from me.
|
||||
Let me confess that we two must be twain,
|
||||
Although our undivided loves are one:
|
||||
So shall those blots that do with me remain
|
||||
Without thy help by me be borne alone.
|
||||
In our two loves there is but one respect,
|
||||
Though in our lives a separable spite,
|
||||
Which though it alter not love's sole effect,
|
||||
Yet doth it steal sweet hours from love's delight.
|
||||
I may not evermore acknowledge thee,
|
||||
Lest my bewailed guilt should do thee shame,
|
||||
Nor thou with public kindness honour me,
|
||||
Unless thou take that honour from thy name:
|
||||
But do not so; I love thee in such sort
|
||||
As, thou being mine, mine is thy good report.
|
||||
As a decrepit father takes delight
|
||||
To see his active child do deeds of youth,
|
||||
So I, made lame by fortune's dearest spite,
|
||||
Take all my comfort of thy worth and truth.
|
||||
For whether beauty, birth, or wealth, or wit,
|
||||
Or any of these all, or all, or more,
|
||||
Entitled in thy parts do crowned sit,
|
||||
I make my love engrafted to this store:
|
||||
So then I am not lame, poor, nor despised,
|
||||
Whilst that this shadow doth such substance give
|
||||
That I in thy abundance am sufficed
|
||||
And by a part of all thy glory live.
|
||||
Look, what is best, that best I wish in thee:
|
||||
This wish I have; then ten times happy me!
|
73
cmake/hipify.py
Executable file
73
cmake/hipify.py
Executable file
@ -0,0 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
#
|
||||
# A command line tool for running pytorch's hipify preprocessor on CUDA
|
||||
# source files.
|
||||
#
|
||||
# See https://github.com/ROCm/hipify_torch
|
||||
# and <torch install dir>/utils/hipify/hipify_python.py
|
||||
#
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from torch.utils.hipify.hipify_python import hipify
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Project directory where all the source + include files live.
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--project_dir",
|
||||
help="The project directory.",
|
||||
)
|
||||
|
||||
# Directory where hipified files are written.
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
help="The output directory.",
|
||||
)
|
||||
|
||||
# Source files to convert.
|
||||
parser.add_argument("sources",
|
||||
help="Source files to hipify.",
|
||||
nargs="*",
|
||||
default=[])
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Limit include scope to project_dir only
|
||||
includes = [os.path.join(args.project_dir, '*')]
|
||||
|
||||
# Get absolute path for all source files.
|
||||
extra_files = [os.path.abspath(s) for s in args.sources]
|
||||
|
||||
# Copy sources from project directory to output directory.
|
||||
# The directory might already exist to hold object files so we ignore that.
|
||||
shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
|
||||
|
||||
hipify_result = hipify(project_directory=args.project_dir,
|
||||
output_directory=args.output_dir,
|
||||
header_include_dirs=[],
|
||||
includes=includes,
|
||||
extra_files=extra_files,
|
||||
show_detailed=True,
|
||||
is_pytorch_extension=True,
|
||||
hipify_extra_files_only=True)
|
||||
|
||||
hipified_sources = []
|
||||
for source in args.sources:
|
||||
s_abs = os.path.abspath(source)
|
||||
hipified_s_abs = (hipify_result[s_abs].hipified_path if
|
||||
(s_abs in hipify_result
|
||||
and hipify_result[s_abs].hipified_path is not None)
|
||||
else s_abs)
|
||||
hipified_sources.append(hipified_s_abs)
|
||||
|
||||
assert (len(hipified_sources) == len(args.sources))
|
||||
|
||||
# Print hipified source files.
|
||||
print("\n".join(hipified_sources))
|
346
cmake/utils.cmake
Normal file
346
cmake/utils.cmake
Normal file
@ -0,0 +1,346 @@
|
||||
#
|
||||
# Attempt to find the python package that uses the same python executable as
|
||||
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
|
||||
#
|
||||
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
|
||||
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
|
||||
set(Python_EXECUTABLE ${EXECUTABLE})
|
||||
find_package(Python COMPONENTS Interpreter Development.Module)
|
||||
if (NOT Python_FOUND)
|
||||
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
||||
endif()
|
||||
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
|
||||
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
|
||||
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
|
||||
message(FATAL_ERROR
|
||||
"Python version (${_VER}) is not one of the supported versions: "
|
||||
"${_SUPPORTED_VERSIONS_LIST}.")
|
||||
endif()
|
||||
message(STATUS "Found python matching: ${EXECUTABLE}.")
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
|
||||
# has trailing whitespace stripped. If an error is encountered when running
|
||||
# python, a fatal message `ERR_MSG` is issued.
|
||||
#
|
||||
function (run_python OUT EXPR ERR_MSG)
|
||||
execute_process(
|
||||
COMMAND
|
||||
"${Python_EXECUTABLE}" "-c" "${EXPR}"
|
||||
OUTPUT_VARIABLE PYTHON_OUT
|
||||
RESULT_VARIABLE PYTHON_ERROR_CODE
|
||||
ERROR_VARIABLE PYTHON_STDERR
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
if(NOT PYTHON_ERROR_CODE EQUAL 0)
|
||||
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
|
||||
endif()
|
||||
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
|
||||
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
|
||||
macro (append_cmake_prefix_path PKG EXPR)
|
||||
run_python(_PREFIX_PATH
|
||||
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
|
||||
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
|
||||
# of CUDA source files. The names of the corresponding "hipified" sources are
|
||||
# stored in `OUT_SRCS`.
|
||||
#
|
||||
function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
||||
#
|
||||
# Split into C++ and non-C++ (i.e. CUDA) sources.
|
||||
#
|
||||
set(SRCS ${ORIG_SRCS})
|
||||
set(CXX_SRCS ${ORIG_SRCS})
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
|
||||
#
|
||||
# Generate ROCm/HIP source file names from CUDA file names.
|
||||
# Since HIP files are generated code, they will appear in the build area
|
||||
# `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
|
||||
#
|
||||
set(HIP_SRCS)
|
||||
foreach (SRC ${SRCS})
|
||||
string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
|
||||
string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
|
||||
list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
|
||||
endforeach()
|
||||
|
||||
set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
|
||||
add_custom_target(
|
||||
hipify${NAME}
|
||||
COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
|
||||
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
|
||||
BYPRODUCTS ${HIP_SRCS}
|
||||
COMMENT "Running hipify on ${NAME} extension source files.")
|
||||
|
||||
# Swap out original extension sources with hipified sources.
|
||||
list(APPEND HIP_SRCS ${CXX_SRCS})
|
||||
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
#
|
||||
# Get additional GPU compiler flags from torch.
|
||||
#
|
||||
function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
||||
if (${GPU_LANG} STREQUAL "CUDA")
|
||||
#
|
||||
# Get common NVCC flags from torch.
|
||||
#
|
||||
run_python(GPU_FLAGS
|
||||
"from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
|
||||
"Failed to determine torch nvcc compiler flags")
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
|
||||
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
|
||||
endif()
|
||||
|
||||
elseif(${GPU_LANG} STREQUAL "HIP")
|
||||
#
|
||||
# Get common HIP/HIPCC flags from torch.
|
||||
#
|
||||
run_python(GPU_FLAGS
|
||||
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
|
||||
"Failed to determine torch nvcc compiler flags")
|
||||
|
||||
list(APPEND GPU_FLAGS
|
||||
"-DUSE_ROCM"
|
||||
"-U__HIP_NO_HALF_CONVERSIONS__"
|
||||
"-U__HIP_NO_HALF_OPERATORS__"
|
||||
"-fno-gpu-rdc")
|
||||
|
||||
endif()
|
||||
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Macro for converting a `gencode` version number to a cmake version number.
|
||||
macro(string_to_ver OUT_VER IN_STR)
|
||||
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
||||
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
|
||||
# `GPU_ARCHES`.
|
||||
#
|
||||
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
|
||||
#
|
||||
macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
||||
set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
|
||||
message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
|
||||
|
||||
if (${GPU_LANG} STREQUAL "HIP")
|
||||
#
|
||||
# `GPU_ARCHES` controls the `--offload-arch` flags.
|
||||
# `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
|
||||
# via the `PYTORCH_ROCM_ARCH` env variable.
|
||||
#
|
||||
|
||||
#
|
||||
# Find the intersection of the supported + detected architectures to
|
||||
# set the module architecture flags.
|
||||
#
|
||||
set(${GPU_ARCHES})
|
||||
foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES})
|
||||
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
||||
list(APPEND ${GPU_ARCHES} ${_ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if(NOT ${GPU_ARCHES})
|
||||
message(FATAL_ERROR
|
||||
"None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
|
||||
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
|
||||
endif()
|
||||
|
||||
elseif(${GPU_LANG} STREQUAL "CUDA")
|
||||
#
|
||||
# Setup/process CUDA arch flags.
|
||||
#
|
||||
# The torch cmake setup hardcodes the detected architecture flags in
|
||||
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
|
||||
# can't modified on a per-target basis, e.g. for the `punica` extension.
|
||||
# So, all the `-gencode` flags need to be extracted and removed from
|
||||
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
|
||||
# Since it's not possible to use `target_compiler_options` for adding target
|
||||
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
|
||||
# must be used instead. This requires repackaging the architecture flags
|
||||
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
|
||||
#
|
||||
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
|
||||
# to one of the other nvcc options to specify architectures.
|
||||
#
|
||||
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
|
||||
# detected architectures.
|
||||
#
|
||||
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
||||
|
||||
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
|
||||
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
|
||||
# and passed back via the `CUDA_ARCHITECTURES` property.
|
||||
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
|
||||
${CMAKE_CUDA_FLAGS})
|
||||
|
||||
# If this error is triggered, it might mean that torch has changed how it sets
|
||||
# up nvcc architecture code generation flags.
|
||||
if (NOT _CUDA_ARCH_FLAGS)
|
||||
message(FATAL_ERROR
|
||||
"Could not find any architecture related code generation flags in "
|
||||
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
|
||||
endif()
|
||||
|
||||
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
|
||||
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
|
||||
|
||||
# Initialize the architecture lists to empty.
|
||||
set(${GPU_ARCHES})
|
||||
|
||||
# Process each `gencode` flag.
|
||||
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
|
||||
# For each flag, extract the version number and whether it refers to PTX
|
||||
# or native code.
|
||||
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
|
||||
# for that match.
|
||||
|
||||
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
|
||||
if (_COMPUTE)
|
||||
set(_COMPUTE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
|
||||
if (_SM)
|
||||
set(_SM ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
|
||||
if (_CODE)
|
||||
set(_CODE ${CMAKE_MATCH_1})
|
||||
endif()
|
||||
|
||||
# Make sure the virtual architecture can be matched.
|
||||
if (NOT _COMPUTE)
|
||||
message(FATAL_ERROR
|
||||
"Could not determine virtual architecture from: ${_ARCH}.")
|
||||
endif()
|
||||
|
||||
# One of sm_ or compute_ must exist.
|
||||
if ((NOT _SM) AND (NOT _CODE))
|
||||
message(FATAL_ERROR
|
||||
"Could not determine a codegen architecture from: ${_ARCH}.")
|
||||
endif()
|
||||
|
||||
if (_SM)
|
||||
# -real suffix let CMake to only generate elf code for the kernels.
|
||||
# we want this, otherwise the added ptx (default) will increase binary size.
|
||||
set(_VIRT "-real")
|
||||
set(_CODE_ARCH ${_SM})
|
||||
else()
|
||||
# -virtual suffix let CMake to generate ptx code for the kernels.
|
||||
set(_VIRT "-virtual")
|
||||
set(_CODE_ARCH ${_CODE})
|
||||
endif()
|
||||
|
||||
# Check if the current version is in the supported arch list.
|
||||
string_to_ver(_CODE_VER ${_CODE_ARCH})
|
||||
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
|
||||
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
|
||||
continue()
|
||||
endif()
|
||||
|
||||
# Add it to the arch list.
|
||||
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
|
||||
endforeach()
|
||||
endif()
|
||||
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Define a target named `GPU_MOD_NAME` for a single extension. The
|
||||
# arguments are:
|
||||
#
|
||||
# DESTINATION <dest> - Module destination directory.
|
||||
# LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
|
||||
# etc.
|
||||
# SOURCES <sources> - List of source files relative to CMakeLists.txt
|
||||
# directory.
|
||||
#
|
||||
# Optional arguments:
|
||||
#
|
||||
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
|
||||
# format.
|
||||
# Refer `CMAKE_CUDA_ARCHITECTURES` documentation
|
||||
# and `CMAKE_HIP_ARCHITECTURES` for more info.
|
||||
# ARCHITECTURES will use cmake's defaults if
|
||||
# not provided.
|
||||
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
|
||||
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
|
||||
# LIBRARIES <libraries> - Extra link libraries.
|
||||
# WITH_SOABI - Generate library with python SOABI suffix name.
|
||||
#
|
||||
# Note: optimization level/debug info is set via cmake build type.
|
||||
#
|
||||
function (define_gpu_extension_target GPU_MOD_NAME)
|
||||
cmake_parse_arguments(PARSE_ARGV 1
|
||||
GPU
|
||||
"WITH_SOABI"
|
||||
"DESTINATION;LANGUAGE"
|
||||
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
|
||||
|
||||
# Add hipify preprocessing step when building with HIP/ROCm.
|
||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
|
||||
endif()
|
||||
|
||||
if (GPU_WITH_SOABI)
|
||||
set(GPU_WITH_SOABI WITH_SOABI)
|
||||
else()
|
||||
set(GPU_WITH_SOABI)
|
||||
endif()
|
||||
|
||||
Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})
|
||||
|
||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||
# Make this target dependent on the hipify preprocessor step.
|
||||
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
|
||||
endif()
|
||||
|
||||
if (GPU_ARCHITECTURES)
|
||||
set_target_properties(${GPU_MOD_NAME} PROPERTIES
|
||||
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
|
||||
|
||||
target_compile_options(${GPU_MOD_NAME} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
|
||||
|
||||
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
|
||||
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
|
||||
|
||||
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
||||
${GPU_INCLUDE_DIRECTORIES})
|
||||
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
|
||||
${GPU_LIBRARIES})
|
||||
|
||||
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
|
||||
# dependencies that are not necessary and may not be installed.
|
||||
if (GPU_LANGUAGE STREQUAL "CUDA")
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
|
||||
${CUDA_LIBRARIES})
|
||||
else()
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
|
||||
endif()
|
||||
|
||||
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
|
||||
endfunction()
|
719
collect_env.py
Normal file
719
collect_env.py
Normal file
@ -0,0 +1,719 @@
|
||||
# ruff: noqa
|
||||
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
|
||||
|
||||
# Unlike the rest of the PyTorch this file must be python2 compliant.
|
||||
# This script outputs relevant system environment info
|
||||
# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
|
||||
import datetime
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except (ImportError, NameError, AttributeError, OSError):
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
# System Environment Information
|
||||
SystemEnv = namedtuple(
|
||||
'SystemEnv',
|
||||
[
|
||||
'torch_version',
|
||||
'is_debug_build',
|
||||
'cuda_compiled_version',
|
||||
'gcc_version',
|
||||
'clang_version',
|
||||
'cmake_version',
|
||||
'os',
|
||||
'libc_version',
|
||||
'python_version',
|
||||
'python_platform',
|
||||
'is_cuda_available',
|
||||
'cuda_runtime_version',
|
||||
'cuda_module_loading',
|
||||
'nvidia_driver_version',
|
||||
'nvidia_gpu_models',
|
||||
'cudnn_version',
|
||||
'pip_version', # 'pip' or 'pip3'
|
||||
'pip_packages',
|
||||
'conda_packages',
|
||||
'hip_compiled_version',
|
||||
'hip_runtime_version',
|
||||
'miopen_runtime_version',
|
||||
'caching_allocator_config',
|
||||
'is_xnnpack_available',
|
||||
'cpu_info',
|
||||
'rocm_version', # vllm specific field
|
||||
'neuron_sdk_version', # vllm specific field
|
||||
'vllm_version', # vllm specific field
|
||||
'vllm_build_flags', # vllm specific field
|
||||
'gpu_topo', # vllm specific field
|
||||
])
|
||||
|
||||
DEFAULT_CONDA_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"cudatoolkit",
|
||||
"soumith",
|
||||
"mkl",
|
||||
"magma",
|
||||
"triton",
|
||||
"optree",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
"torch",
|
||||
"numpy",
|
||||
"mypy",
|
||||
"flake8",
|
||||
"triton",
|
||||
"optree",
|
||||
"onnx",
|
||||
}
|
||||
|
||||
|
||||
def run(command):
|
||||
"""Return (return-code, stdout, stderr)."""
|
||||
shell = True if type(command) is str else False
|
||||
p = subprocess.Popen(command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
shell=shell)
|
||||
raw_output, raw_err = p.communicate()
|
||||
rc = p.returncode
|
||||
if get_platform() == 'win32':
|
||||
enc = 'oem'
|
||||
else:
|
||||
enc = locale.getpreferredencoding()
|
||||
output = raw_output.decode(enc)
|
||||
err = raw_err.decode(enc)
|
||||
return rc, output.strip(), err.strip()
|
||||
|
||||
|
||||
def run_and_read_all(run_lambda, command):
|
||||
"""Run command using run_lambda; reads and returns entire output if rc is 0."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out
|
||||
|
||||
|
||||
def run_and_parse_first_match(run_lambda, command, regex):
|
||||
"""Run command using run_lambda, returns the first regex match if it exists."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
match = re.search(regex, out)
|
||||
if match is None:
|
||||
return None
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def run_and_return_first_line(run_lambda, command):
|
||||
"""Run command using run_lambda and returns first line if output is not empty."""
|
||||
rc, out, _ = run_lambda(command)
|
||||
if rc != 0:
|
||||
return None
|
||||
return out.split('\n')[0]
|
||||
|
||||
|
||||
def get_conda_packages(run_lambda, patterns=None):
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_CONDA_PATTERNS
|
||||
conda = os.environ.get('CONDA_EXE', 'conda')
|
||||
out = run_and_read_all(run_lambda, "{} list".format(conda))
|
||||
if out is None:
|
||||
return out
|
||||
|
||||
return "\n".join(line for line in out.splitlines()
|
||||
if not line.startswith("#") and any(name in line
|
||||
for name in patterns))
|
||||
|
||||
|
||||
def get_gcc_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
|
||||
|
||||
|
||||
def get_clang_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'clang --version',
|
||||
r'clang version (.*)')
|
||||
|
||||
|
||||
def get_cmake_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'cmake --version',
|
||||
r'cmake (.*)')
|
||||
|
||||
|
||||
def get_nvidia_driver_version(run_lambda):
|
||||
if get_platform() == 'darwin':
|
||||
cmd = 'kextstat | grep -i cuda'
|
||||
return run_and_parse_first_match(run_lambda, cmd,
|
||||
r'com[.]nvidia[.]CUDA [(](.*?)[)]')
|
||||
smi = get_nvidia_smi()
|
||||
return run_and_parse_first_match(run_lambda, smi,
|
||||
r'Driver Version: (.*?) ')
|
||||
|
||||
|
||||
def get_gpu_info(run_lambda):
|
||||
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(
|
||||
torch.version, 'hip') and torch.version.hip is not None):
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
if torch.version.hip is not None:
|
||||
prop = torch.cuda.get_device_properties(0)
|
||||
if hasattr(prop, "gcnArchName"):
|
||||
gcnArch = " ({})".format(prop.gcnArchName)
|
||||
else:
|
||||
gcnArch = "NoGCNArchNameOnOldPyTorch"
|
||||
else:
|
||||
gcnArch = ""
|
||||
return torch.cuda.get_device_name(None) + gcnArch
|
||||
return None
|
||||
smi = get_nvidia_smi()
|
||||
uuid_regex = re.compile(r' \(UUID: .+?\)')
|
||||
rc, out, _ = run_lambda(smi + ' -L')
|
||||
if rc != 0:
|
||||
return None
|
||||
# Anonymize GPUs by removing their UUID
|
||||
return re.sub(uuid_regex, '', out)
|
||||
|
||||
|
||||
def get_running_cuda_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'nvcc --version',
|
||||
r'release .+ V(.*)')
|
||||
|
||||
|
||||
def get_cudnn_version(run_lambda):
|
||||
"""Return a list of libcudnn.so; it's hard to tell which one is being used."""
|
||||
if get_platform() == 'win32':
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
|
||||
where_cmd = os.path.join(system_root, 'System32', 'where')
|
||||
cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
|
||||
elif get_platform() == 'darwin':
|
||||
# CUDA libraries and drivers can be found in /usr/local/cuda/. See
|
||||
# https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
|
||||
# https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
|
||||
# Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
|
||||
cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*'
|
||||
else:
|
||||
cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
|
||||
rc, out, _ = run_lambda(cudnn_cmd)
|
||||
# find will return 1 if there are permission errors or if not found
|
||||
if len(out) == 0 or (rc != 1 and rc != 0):
|
||||
l = os.environ.get('CUDNN_LIBRARY')
|
||||
if l is not None and os.path.isfile(l):
|
||||
return os.path.realpath(l)
|
||||
return None
|
||||
files_set = set()
|
||||
for fn in out.split('\n'):
|
||||
fn = os.path.realpath(fn) # eliminate symbolic links
|
||||
if os.path.isfile(fn):
|
||||
files_set.add(fn)
|
||||
if not files_set:
|
||||
return None
|
||||
# Alphabetize the result because the order is non-deterministic otherwise
|
||||
files = sorted(files_set)
|
||||
if len(files) == 1:
|
||||
return files[0]
|
||||
result = '\n'.join(files)
|
||||
return 'Probably one of the following:\n{}'.format(result)
|
||||
|
||||
|
||||
def get_nvidia_smi():
|
||||
# Note: nvidia-smi is currently available only on Windows and Linux
|
||||
smi = 'nvidia-smi'
|
||||
if get_platform() == 'win32':
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
program_files_root = os.environ.get('PROGRAMFILES',
|
||||
'C:\\Program Files')
|
||||
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation',
|
||||
'NVSMI', smi)
|
||||
new_path = os.path.join(system_root, 'System32', smi)
|
||||
smis = [new_path, legacy_path]
|
||||
for candidate_smi in smis:
|
||||
if os.path.exists(candidate_smi):
|
||||
smi = '"{}"'.format(candidate_smi)
|
||||
break
|
||||
return smi
|
||||
|
||||
|
||||
def get_rocm_version(run_lambda):
|
||||
"""Returns the ROCm version if available, otherwise 'N/A'."""
|
||||
return run_and_parse_first_match(run_lambda, 'hipcc --version',
|
||||
r'HIP version: (\S+)')
|
||||
|
||||
|
||||
def get_neuron_sdk_version(run_lambda):
|
||||
# Adapted from your install script
|
||||
try:
|
||||
result = run_lambda(["neuron-ls"])
|
||||
return result if result[0] == 0 else 'N/A'
|
||||
except Exception:
|
||||
return 'N/A'
|
||||
|
||||
|
||||
def get_vllm_version():
|
||||
try:
|
||||
import vllm
|
||||
return vllm.__version__
|
||||
except ImportError:
|
||||
return 'N/A'
|
||||
|
||||
|
||||
def summarize_vllm_build_flags():
|
||||
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
|
||||
return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(
|
||||
os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set'),
|
||||
'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled',
|
||||
'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled',
|
||||
)
|
||||
|
||||
|
||||
def get_gpu_topo(run_lambda):
|
||||
if get_platform() == 'linux':
|
||||
return run_and_read_all(run_lambda, 'nvidia-smi topo -m')
|
||||
return None
|
||||
|
||||
|
||||
# example outputs of CPU infos
|
||||
# * linux
|
||||
# Architecture: x86_64
|
||||
# CPU op-mode(s): 32-bit, 64-bit
|
||||
# Address sizes: 46 bits physical, 48 bits virtual
|
||||
# Byte Order: Little Endian
|
||||
# CPU(s): 128
|
||||
# On-line CPU(s) list: 0-127
|
||||
# Vendor ID: GenuineIntel
|
||||
# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# CPU family: 6
|
||||
# Model: 106
|
||||
# Thread(s) per core: 2
|
||||
# Core(s) per socket: 32
|
||||
# Socket(s): 2
|
||||
# Stepping: 6
|
||||
# BogoMIPS: 5799.78
|
||||
# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
|
||||
# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
|
||||
# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
|
||||
# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
|
||||
# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
|
||||
# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
|
||||
# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
|
||||
# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
|
||||
# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
|
||||
# Virtualization features:
|
||||
# Hypervisor vendor: KVM
|
||||
# Virtualization type: full
|
||||
# Caches (sum of all):
|
||||
# L1d: 3 MiB (64 instances)
|
||||
# L1i: 2 MiB (64 instances)
|
||||
# L2: 80 MiB (64 instances)
|
||||
# L3: 108 MiB (2 instances)
|
||||
# NUMA:
|
||||
# NUMA node(s): 2
|
||||
# NUMA node0 CPU(s): 0-31,64-95
|
||||
# NUMA node1 CPU(s): 32-63,96-127
|
||||
# Vulnerabilities:
|
||||
# Itlb multihit: Not affected
|
||||
# L1tf: Not affected
|
||||
# Mds: Not affected
|
||||
# Meltdown: Not affected
|
||||
# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
|
||||
# Retbleed: Not affected
|
||||
# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
|
||||
# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
|
||||
# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
|
||||
# Srbds: Not affected
|
||||
# Tsx async abort: Not affected
|
||||
# * win32
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU0
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
#
|
||||
# Architecture=9
|
||||
# CurrentClockSpeed=2900
|
||||
# DeviceID=CPU1
|
||||
# Family=179
|
||||
# L2CacheSize=40960
|
||||
# L2CacheSpeed=
|
||||
# Manufacturer=GenuineIntel
|
||||
# MaxClockSpeed=2900
|
||||
# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
|
||||
# ProcessorType=3
|
||||
# Revision=27142
|
||||
|
||||
|
||||
def get_cpu_info(run_lambda):
|
||||
rc, out, err = 0, '', ''
|
||||
if get_platform() == 'linux':
|
||||
rc, out, err = run_lambda('lscpu')
|
||||
elif get_platform() == 'win32':
|
||||
rc, out, err = run_lambda(
|
||||
'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
|
||||
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE'
|
||||
)
|
||||
elif get_platform() == 'darwin':
|
||||
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
|
||||
cpu_info = 'None'
|
||||
if rc == 0:
|
||||
cpu_info = out
|
||||
else:
|
||||
cpu_info = err
|
||||
return cpu_info
|
||||
|
||||
|
||||
def get_platform():
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'linux'
|
||||
elif sys.platform.startswith('win32'):
|
||||
return 'win32'
|
||||
elif sys.platform.startswith('cygwin'):
|
||||
return 'cygwin'
|
||||
elif sys.platform.startswith('darwin'):
|
||||
return 'darwin'
|
||||
else:
|
||||
return sys.platform
|
||||
|
||||
|
||||
def get_mac_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion',
|
||||
r'(.*)')
|
||||
|
||||
|
||||
def get_windows_version(run_lambda):
|
||||
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
|
||||
wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic')
|
||||
findstr_cmd = os.path.join(system_root, 'System32', 'findstr')
|
||||
return run_and_read_all(
|
||||
run_lambda,
|
||||
'{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
|
||||
|
||||
|
||||
def get_lsb_version(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'lsb_release -a',
|
||||
r'Description:\t(.*)')
|
||||
|
||||
|
||||
def check_release_file(run_lambda):
|
||||
return run_and_parse_first_match(run_lambda, 'cat /etc/*-release',
|
||||
r'PRETTY_NAME="(.*)"')
|
||||
|
||||
|
||||
def get_os(run_lambda):
|
||||
from platform import machine
|
||||
platform = get_platform()
|
||||
|
||||
if platform == 'win32' or platform == 'cygwin':
|
||||
return get_windows_version(run_lambda)
|
||||
|
||||
if platform == 'darwin':
|
||||
version = get_mac_version(run_lambda)
|
||||
if version is None:
|
||||
return None
|
||||
return 'macOS {} ({})'.format(version, machine())
|
||||
|
||||
if platform == 'linux':
|
||||
# Ubuntu/Debian based
|
||||
desc = get_lsb_version(run_lambda)
|
||||
if desc is not None:
|
||||
return '{} ({})'.format(desc, machine())
|
||||
|
||||
# Try reading /etc/*-release
|
||||
desc = check_release_file(run_lambda)
|
||||
if desc is not None:
|
||||
return '{} ({})'.format(desc, machine())
|
||||
|
||||
return '{} ({})'.format(platform, machine())
|
||||
|
||||
# Unknown platform
|
||||
return platform
|
||||
|
||||
|
||||
def get_python_platform():
|
||||
import platform
|
||||
return platform.platform()
|
||||
|
||||
|
||||
def get_libc_version():
|
||||
import platform
|
||||
if get_platform() != 'linux':
|
||||
return 'N/A'
|
||||
return '-'.join(platform.libc_ver())
|
||||
|
||||
|
||||
def get_pip_packages(run_lambda, patterns=None):
|
||||
"""Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
|
||||
if patterns is None:
|
||||
patterns = DEFAULT_PIP_PATTERNS
|
||||
|
||||
# People generally have `pip` as `pip` or `pip3`
|
||||
# But here it is invoked as `python -mpip`
|
||||
def run_with_pip(pip):
|
||||
out = run_and_read_all(run_lambda, pip + ["list", "--format=freeze"])
|
||||
return "\n".join(line for line in out.splitlines()
|
||||
if any(name in line for name in patterns))
|
||||
|
||||
pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
|
||||
out = run_with_pip([sys.executable, '-mpip'])
|
||||
|
||||
return pip_version, out
|
||||
|
||||
|
||||
def get_cachingallocator_config():
|
||||
ca_config = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')
|
||||
return ca_config
|
||||
|
||||
|
||||
def get_cuda_module_loading_config():
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
torch.cuda.init()
|
||||
config = os.environ.get('CUDA_MODULE_LOADING', '')
|
||||
return config
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def is_xnnpack_available():
|
||||
if TORCH_AVAILABLE:
|
||||
import torch.backends.xnnpack
|
||||
return str(
|
||||
torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
|
||||
else:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def get_env_info():
|
||||
run_lambda = run
|
||||
pip_version, pip_list_output = get_pip_packages(run_lambda)
|
||||
|
||||
if TORCH_AVAILABLE:
|
||||
version_str = torch.__version__
|
||||
debug_mode_str = str(torch.version.debug)
|
||||
cuda_available_str = str(torch.cuda.is_available())
|
||||
cuda_version_str = torch.version.cuda
|
||||
if not hasattr(torch.version,
|
||||
'hip') or torch.version.hip is None: # cuda version
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
|
||||
else: # HIP version
|
||||
|
||||
def get_version_or_na(cfg, prefix):
|
||||
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
|
||||
return _lst[0] if _lst else 'N/A'
|
||||
|
||||
cfg = torch._C._show_config().split('\n')
|
||||
hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime')
|
||||
miopen_runtime_version = get_version_or_na(cfg, 'MIOpen')
|
||||
cuda_version_str = 'N/A'
|
||||
hip_compiled_version = torch.version.hip
|
||||
else:
|
||||
version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A'
|
||||
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
|
||||
|
||||
sys_version = sys.version.replace("\n", " ")
|
||||
|
||||
conda_packages = get_conda_packages(run_lambda)
|
||||
|
||||
rocm_version = get_rocm_version(run_lambda)
|
||||
neuron_sdk_version = get_neuron_sdk_version(run_lambda)
|
||||
vllm_version = get_vllm_version()
|
||||
vllm_build_flags = summarize_vllm_build_flags()
|
||||
gpu_topo = get_gpu_topo(run_lambda)
|
||||
|
||||
return SystemEnv(
|
||||
torch_version=version_str,
|
||||
is_debug_build=debug_mode_str,
|
||||
python_version='{} ({}-bit runtime)'.format(
|
||||
sys_version,
|
||||
sys.maxsize.bit_length() + 1),
|
||||
python_platform=get_python_platform(),
|
||||
is_cuda_available=cuda_available_str,
|
||||
cuda_compiled_version=cuda_version_str,
|
||||
cuda_runtime_version=get_running_cuda_version(run_lambda),
|
||||
cuda_module_loading=get_cuda_module_loading_config(),
|
||||
nvidia_gpu_models=get_gpu_info(run_lambda),
|
||||
nvidia_driver_version=get_nvidia_driver_version(run_lambda),
|
||||
cudnn_version=get_cudnn_version(run_lambda),
|
||||
hip_compiled_version=hip_compiled_version,
|
||||
hip_runtime_version=hip_runtime_version,
|
||||
miopen_runtime_version=miopen_runtime_version,
|
||||
pip_version=pip_version,
|
||||
pip_packages=pip_list_output,
|
||||
conda_packages=conda_packages,
|
||||
os=get_os(run_lambda),
|
||||
libc_version=get_libc_version(),
|
||||
gcc_version=get_gcc_version(run_lambda),
|
||||
clang_version=get_clang_version(run_lambda),
|
||||
cmake_version=get_cmake_version(run_lambda),
|
||||
caching_allocator_config=get_cachingallocator_config(),
|
||||
is_xnnpack_available=is_xnnpack_available(),
|
||||
cpu_info=get_cpu_info(run_lambda),
|
||||
rocm_version=rocm_version,
|
||||
neuron_sdk_version=neuron_sdk_version,
|
||||
vllm_version=vllm_version,
|
||||
vllm_build_flags=vllm_build_flags,
|
||||
gpu_topo=gpu_topo,
|
||||
)
|
||||
|
||||
|
||||
env_info_fmt = """
|
||||
PyTorch version: {torch_version}
|
||||
Is debug build: {is_debug_build}
|
||||
CUDA used to build PyTorch: {cuda_compiled_version}
|
||||
ROCM used to build PyTorch: {hip_compiled_version}
|
||||
|
||||
OS: {os}
|
||||
GCC version: {gcc_version}
|
||||
Clang version: {clang_version}
|
||||
CMake version: {cmake_version}
|
||||
Libc version: {libc_version}
|
||||
|
||||
Python version: {python_version}
|
||||
Python platform: {python_platform}
|
||||
Is CUDA available: {is_cuda_available}
|
||||
CUDA runtime version: {cuda_runtime_version}
|
||||
CUDA_MODULE_LOADING set to: {cuda_module_loading}
|
||||
GPU models and configuration: {nvidia_gpu_models}
|
||||
Nvidia driver version: {nvidia_driver_version}
|
||||
cuDNN version: {cudnn_version}
|
||||
HIP runtime version: {hip_runtime_version}
|
||||
MIOpen runtime version: {miopen_runtime_version}
|
||||
Is XNNPACK available: {is_xnnpack_available}
|
||||
|
||||
CPU:
|
||||
{cpu_info}
|
||||
|
||||
Versions of relevant libraries:
|
||||
{pip_packages}
|
||||
{conda_packages}
|
||||
""".strip()
|
||||
|
||||
env_info_fmt += """
|
||||
ROCM Version: {rocm_version}
|
||||
Neuron SDK Version: {neuron_sdk_version}
|
||||
vLLM Version: {vllm_version}
|
||||
vLLM Build Flags:
|
||||
{vllm_build_flags}
|
||||
GPU Topology:
|
||||
{gpu_topo}
|
||||
""".strip()
|
||||
|
||||
|
||||
def pretty_str(envinfo):
|
||||
|
||||
def replace_nones(dct, replacement='Could not collect'):
|
||||
for key in dct.keys():
|
||||
if dct[key] is not None:
|
||||
continue
|
||||
dct[key] = replacement
|
||||
return dct
|
||||
|
||||
def replace_bools(dct, true='Yes', false='No'):
|
||||
for key in dct.keys():
|
||||
if dct[key] is True:
|
||||
dct[key] = true
|
||||
elif dct[key] is False:
|
||||
dct[key] = false
|
||||
return dct
|
||||
|
||||
def prepend(text, tag='[prepend]'):
|
||||
lines = text.split('\n')
|
||||
updated_lines = [tag + line for line in lines]
|
||||
return '\n'.join(updated_lines)
|
||||
|
||||
def replace_if_empty(text, replacement='No relevant packages'):
|
||||
if text is not None and len(text) == 0:
|
||||
return replacement
|
||||
return text
|
||||
|
||||
def maybe_start_on_next_line(string):
|
||||
# If `string` is multiline, prepend a \n to it.
|
||||
if string is not None and len(string.split('\n')) > 1:
|
||||
return '\n{}\n'.format(string)
|
||||
return string
|
||||
|
||||
mutable_dict = envinfo._asdict()
|
||||
|
||||
# If nvidia_gpu_models is multiline, start on the next line
|
||||
mutable_dict['nvidia_gpu_models'] = \
|
||||
maybe_start_on_next_line(envinfo.nvidia_gpu_models)
|
||||
|
||||
# If the machine doesn't have CUDA, report some fields as 'No CUDA'
|
||||
dynamic_cuda_fields = [
|
||||
'cuda_runtime_version',
|
||||
'nvidia_gpu_models',
|
||||
'nvidia_driver_version',
|
||||
]
|
||||
all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
|
||||
all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None
|
||||
for field in dynamic_cuda_fields)
|
||||
if TORCH_AVAILABLE and not torch.cuda.is_available(
|
||||
) and all_dynamic_cuda_fields_missing:
|
||||
for field in all_cuda_fields:
|
||||
mutable_dict[field] = 'No CUDA'
|
||||
if envinfo.cuda_compiled_version is None:
|
||||
mutable_dict['cuda_compiled_version'] = 'None'
|
||||
|
||||
# Replace True with Yes, False with No
|
||||
mutable_dict = replace_bools(mutable_dict)
|
||||
|
||||
# Replace all None objects with 'Could not collect'
|
||||
mutable_dict = replace_nones(mutable_dict)
|
||||
|
||||
# If either of these are '', replace with 'No relevant packages'
|
||||
mutable_dict['pip_packages'] = replace_if_empty(
|
||||
mutable_dict['pip_packages'])
|
||||
mutable_dict['conda_packages'] = replace_if_empty(
|
||||
mutable_dict['conda_packages'])
|
||||
|
||||
# Tag conda and pip packages with a prefix
|
||||
# If they were previously None, they'll show up as ie '[conda] Could not collect'
|
||||
if mutable_dict['pip_packages']:
|
||||
mutable_dict['pip_packages'] = prepend(
|
||||
mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version))
|
||||
if mutable_dict['conda_packages']:
|
||||
mutable_dict['conda_packages'] = prepend(
|
||||
mutable_dict['conda_packages'], '[conda] ')
|
||||
mutable_dict['cpu_info'] = envinfo.cpu_info
|
||||
return env_info_fmt.format(**mutable_dict)
|
||||
|
||||
|
||||
def get_pretty_env_info():
|
||||
return pretty_str(get_env_info())
|
||||
|
||||
|
||||
def main():
|
||||
print("Collecting environment information...")
|
||||
output = get_pretty_env_info()
|
||||
print(output)
|
||||
|
||||
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(
|
||||
torch.utils, '_crash_handler'):
|
||||
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
|
||||
if sys.platform == "linux" and os.path.exists(minidump_dir):
|
||||
dumps = [
|
||||
os.path.join(minidump_dir, dump)
|
||||
for dump in os.listdir(minidump_dir)
|
||||
]
|
||||
latest = max(dumps, key=os.path.getctime)
|
||||
ctime = os.path.getctime(latest)
|
||||
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
|
||||
'%Y-%m-%d %H:%M:%S')
|
||||
msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
|
||||
"if this is related to your bug please include it when you file a report ***"
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -1,28 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
m.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
m.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
}
|
@ -1,50 +1,96 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void silu_and_mul_kernel(
|
||||
// Activation and gating kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = silu(x) * y;
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x) * y;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||
// x * sigmoid(x)
|
||||
return (T) (((float) x) / (1.0f + expf((float) -x)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'none' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
|
||||
const float f = (float) x;
|
||||
constexpr float ALPHA = M_SQRT1_2;
|
||||
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
// Equivalent to PyTorch GELU with 'tanh' approximation.
|
||||
// Refer to:
|
||||
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
|
||||
const float f = (float) x;
|
||||
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
|
||||
constexpr float KAPPA = 0.044715;
|
||||
float x_cube = f * f * f;
|
||||
float inner = BETA * (f + KAPPA * x_cube);
|
||||
return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation and gating kernel.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"act_and_mul_kernel", \
|
||||
[&] { \
|
||||
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
}
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(d, 1024));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"silu_and_mul_kernel",
|
||||
[&] {
|
||||
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
d);
|
||||
});
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -57,7 +103,7 @@ __global__ void activation_kernel(
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
@ -70,6 +116,7 @@ __global__ void activation_kernel(
|
||||
int64_t num_tokens = input.numel() / d; \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
|
@ -1,42 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
m.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
}
|
@ -4,3 +4,4 @@
|
||||
#include "dtype_float16.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
#include "dtype_bfloat16.cuh"
|
||||
#include "dtype_fp8_e5m2.cuh"
|
||||
|
@ -15,15 +15,19 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#define WARP_SIZE 32
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@ -40,7 +44,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Compute the sum per warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Warp leaders store the data to shared memory.
|
||||
@ -59,29 +63,31 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Parallel reduction inside the warp.
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Broadcast to other threads.
|
||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||
return VLLM_SHFL_SYNC(sum, 0);
|
||||
}
|
||||
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -124,7 +130,8 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int kv_head_idx = head_mapping[head_idx];
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
@ -135,6 +142,9 @@ __device__ void paged_attention_kernel(
|
||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||
#endif
|
||||
|
||||
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||
@ -166,7 +176,7 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||
// Each thread group fetches x elements from the key at a time.
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
constexpr int x = 16 / sizeof(cache_t);
|
||||
float qk_max = -FLT_MAX;
|
||||
|
||||
// Iterate over the key blocks.
|
||||
@ -192,14 +202,24 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec.
|
||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute dot product.
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
@ -223,7 +243,7 @@ __device__ void paged_attention_kernel(
|
||||
// The 0-th thread of each thread group already has its max qk value.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
@ -235,10 +255,10 @@ __device__ void paged_attention_kernel(
|
||||
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
// Broadcast the max qk value to all threads.
|
||||
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
|
||||
// Get the sum of the exp values.
|
||||
float exp_sum = 0.f;
|
||||
@ -272,6 +292,9 @@ __device__ void paged_attention_kernel(
|
||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||
#endif
|
||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||
|
||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||
@ -297,14 +320,25 @@ __device__ void paged_attention_kernel(
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||
|
||||
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
V_vec v_vec;
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
}
|
||||
if (block_idx == num_context_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
@ -326,7 +360,7 @@ __device__ void paged_attention_kernel(
|
||||
float acc = accs[i];
|
||||
#pragma unroll
|
||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
||||
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||
}
|
||||
accs[i] = acc;
|
||||
}
|
||||
@ -385,15 +419,17 @@ __device__ void paged_attention_kernel(
|
||||
// Grid: (num_heads, num_seqs, 1).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS>
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE>
|
||||
__global__ void paged_attention_v1_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -402,27 +438,29 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -431,8 +469,8 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
@ -492,7 +530,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
// Reduce within the warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = max_logit;
|
||||
@ -502,10 +540,10 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
// Broadcast the max value to all threads.
|
||||
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
|
||||
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||
|
||||
// Load rescaled exp sums to shared memory.
|
||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
@ -539,16 +577,16 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
cudaFuncSetAttribute( \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
@ -561,14 +599,16 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int NUM_THREADS = 128>
|
||||
void paged_attention_v1_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -592,9 +632,8 @@ void paged_attention_v1_launcher(
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -608,6 +647,7 @@ void paged_attention_v1_launcher(
|
||||
|
||||
dim3 grid(num_heads, num_seqs, 1);
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
@ -637,13 +677,13 @@ void paged_attention_v1_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v1_launcher<T, BLOCK_SIZE>( \
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
@ -652,16 +692,16 @@ void paged_attention_v1_launcher(
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V1_LAUNCHER(T, 8); \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V1_LAUNCHER(T, 16); \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V1_LAUNCHER(T, 32); \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
@ -673,26 +713,42 @@ void paged_attention_v1(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float);
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||
vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
@ -700,7 +756,7 @@ void paged_attention_v1(
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
@ -720,7 +776,9 @@ void paged_attention_v1(
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int NUM_THREADS = 128,
|
||||
int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_launcher(
|
||||
@ -731,7 +789,7 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -758,9 +816,8 @@ void paged_attention_v2_launcher(
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -777,6 +834,7 @@ void paged_attention_v2_launcher(
|
||||
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
@ -806,8 +864,8 @@ void paged_attention_v2_launcher(
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \
|
||||
paged_attention_v2_launcher<T, BLOCK_SIZE>( \
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
@ -815,7 +873,7 @@ void paged_attention_v2_launcher(
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
@ -824,16 +882,16 @@ void paged_attention_v2_launcher(
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V2_LAUNCHER(T, 8); \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V2_LAUNCHER(T, 16); \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V2_LAUNCHER(T, 32); \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
@ -848,22 +906,37 @@ void paged_attention_v2(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float);
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
|
@ -17,6 +17,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "attention_dtypes.h"
|
||||
|
||||
#include <float.h>
|
||||
@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
float qk = sum(qk_vec);
|
||||
#pragma unroll
|
||||
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
||||
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
|
||||
}
|
||||
return qk;
|
||||
}
|
||||
|
@ -21,8 +21,17 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
#ifndef USE_ROCM
|
||||
return a + b;
|
||||
#else
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,10 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@ -63,21 +67,47 @@ struct FloatVec<uint4> {
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||
#ifndef USE_ROCM
|
||||
uint32_t b;
|
||||
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
||||
return b;
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u16[0] = a;
|
||||
tmp.u16[1] = a;
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ float half_to_float(uint16_t h) {
|
||||
float f;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
||||
#else
|
||||
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
|
||||
#endif
|
||||
return f;
|
||||
}
|
||||
|
||||
inline __device__ float2 half2_to_float2(uint32_t v) {
|
||||
#ifndef USE_ROCM
|
||||
uint16_t lo, hi;
|
||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
||||
return make_float2(half_to_float(lo), half_to_float(hi));
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u32 = v;
|
||||
float2 ret;
|
||||
ret.x = half_to_float(tmp.u16[0]);
|
||||
ret.y = half_to_float(tmp.u16[1]);
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ uint16_t float_to_half(float f) {
|
||||
@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
||||
#else
|
||||
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
|
||||
#endif
|
||||
return tmp.u16[0];
|
||||
}
|
||||
|
||||
@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
#endif
|
||||
#else
|
||||
tmp.u16[0] = float_to_half(f.x);
|
||||
tmp.u16[1] = float_to_half(f.y);
|
||||
#endif
|
||||
return tmp.u32;
|
||||
}
|
||||
@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
// Vector addition.
|
||||
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
||||
template<>
|
||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
|
||||
|
35
csrc/attention/dtype_fp8_e5m2.cuh
Normal file
35
csrc/attention/dtype_fp8_e5m2.cuh
Normal file
@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.cuh"
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
// fp8 vector types for quantization of kv cache
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 1> {
|
||||
using Type = uint8_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 2> {
|
||||
using Type = uint16_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 4> {
|
||||
using Type = uint32_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 8> {
|
||||
using Type = uint2;
|
||||
};
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
|
||||
} // namespace vllm
|
@ -1,47 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
const std::map<int64_t, int64_t>& block_mapping);
|
||||
|
||||
void copy_blocks(
|
||||
std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
void gather_cached_kv(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
m.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
m.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
m.def(
|
||||
"gather_cached_kv",
|
||||
&gather_cached_kv,
|
||||
"Gather key and value from the cache into contiguous QKV tensors");
|
||||
}
|
29
csrc/cache.h
Normal file
29
csrc/cache.h
Normal file
@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
const std::map<int64_t, int64_t>& block_mapping);
|
||||
|
||||
void copy_blocks(
|
||||
std::vector<torch::Tensor>& key_caches,
|
||||
std::vector<torch::Tensor>& value_caches,
|
||||
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key,
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8_e5m2(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache);
|
@ -1,13 +1,23 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
@ -28,10 +38,11 @@ void swap_blocks(
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
void *src_ptr = src.data_ptr();
|
||||
void *dst_ptr = dst.data_ptr();
|
||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||
for (const auto& pair : block_mapping) {
|
||||
@ -126,8 +137,9 @@ void copy_blocks(
|
||||
const int numel_per_block = key_caches[0][0].numel();
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
@ -139,12 +151,12 @@ void copy_blocks(
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
@ -181,19 +193,45 @@ __global__ void reshape_and_cache_kernel(
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
key_cache[tgt_key_idx] = key[src_key_idx];
|
||||
value_cache[tgt_value_idx] = value[src_value_idx];
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (is_fp8_e5m2_kv_cache) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
key_cache[tgt_key_idx] = tgt_key;
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [num_tokens]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
@ -206,182 +244,77 @@ void reshape_and_cache(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
});
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Grid: (num_blocks, block_size).
|
||||
template<typename scalar_t>
|
||||
__global__ void gather_cached_kv_kernel(
|
||||
scalar_t* __restrict__ key, // [num_tokens, [stride], num_heads, head_size]
|
||||
scalar_t* __restrict__ value, // [num_tokens, [stride], num_heads, head_size]
|
||||
const scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x) {
|
||||
const int token_idx = blockIdx.x;
|
||||
const int slot_idx = slot_mapping[token_idx];
|
||||
const int block_idx = slot_idx / block_size;
|
||||
const int block_offset = slot_idx % block_size;
|
||||
|
||||
const int num_tokens = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
|
||||
const int tgt_key_idx = token_idx * key_stride + i;
|
||||
const int tgt_value_idx = token_idx * value_stride + i;
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension
|
||||
const int x_offset = head_offset % x;
|
||||
|
||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int src_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
|
||||
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void gather_cached_kv_kernel_optimized(
|
||||
scalar_t *__restrict__ key, // [num_tokens, [stride], num_heads, head_size]
|
||||
scalar_t *__restrict__ value, // [num_tokens, [stride], num_heads, head_size]
|
||||
const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int *__restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x)
|
||||
{
|
||||
const int token_idx = blockIdx.x;
|
||||
const int slot_idx = slot_mapping[token_idx];
|
||||
const int block_idx = slot_idx / block_size;
|
||||
const int block_offset = slot_idx % block_size;
|
||||
|
||||
const int dim = num_heads * head_size;
|
||||
assert(dim % 4 == 0); // this is true for known use cases
|
||||
const int unroll_factor = 4;
|
||||
const int unrolled_dim = dim / unroll_factor;
|
||||
|
||||
for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
|
||||
{
|
||||
int tgt_key_indices[unroll_factor];
|
||||
int tgt_value_indices[unroll_factor];
|
||||
int src_key_indices[unroll_factor];
|
||||
int src_value_indices[unroll_factor];
|
||||
scalar_t keys_to_store[unroll_factor];
|
||||
scalar_t values_to_store[unroll_factor];
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < unroll_factor; ++j)
|
||||
{
|
||||
int index = i + j * unrolled_dim;
|
||||
|
||||
const int tgt_key_idx = token_idx * key_stride + index;
|
||||
const int tgt_value_idx = token_idx * value_stride + index;
|
||||
|
||||
const int head_idx = index / head_size;
|
||||
const int head_offset = index % head_size;
|
||||
const int x_idx = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
|
||||
const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int src_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
|
||||
tgt_key_indices[j] = tgt_key_idx;
|
||||
tgt_value_indices[j] = tgt_value_idx;
|
||||
src_key_indices[j] = src_key_idx;
|
||||
src_value_indices[j] = src_value_idx;
|
||||
|
||||
keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < unroll_factor; ++j)
|
||||
{
|
||||
key[tgt_key_indices[j]] = keys_to_store[j];
|
||||
value[tgt_value_indices[j]] = values_to_store[j];
|
||||
}
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_e5m2_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void gather_cached_kv(
|
||||
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [out] [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [in] [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [in] [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [in] [num_tokens]
|
||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
||||
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
|
||||
void convert_fp8_e5m2(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = key_cache.size(3);
|
||||
int x = key_cache.size(4);
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
int64_t block_stride = src_cache.stride(0);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(std::min(block_stride, int64_t(512)));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(),
|
||||
"gather_cached_kv_kernel_optimized",
|
||||
[&] {
|
||||
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
});
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
||||
}
|
||||
}
|
||||
|
38
csrc/cuda_compat.h
Normal file
38
csrc/cuda_compat.h
Normal file
@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_LDG(arg) __ldg(arg)
|
||||
#else
|
||||
#define VLLM_LDG(arg) *(arg)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
||||
#else
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
||||
#else
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#else
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
@ -1,13 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
}
|
||||
|
10
csrc/cuda_utils.h
Normal file
10
csrc/cuda_utils.h
Normal file
@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id);
|
@ -1,3 +1,7 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#endif
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
@ -12,3 +16,20 @@ int get_device_attribute(
|
||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id)
|
||||
{
|
||||
int attribute;
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||
|
||||
#ifdef USE_ROCM
|
||||
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
|
||||
#else
|
||||
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
|
||||
#endif
|
||||
|
||||
return get_device_attribute(attribute, device_id);
|
||||
}
|
||||
|
148
csrc/custom_all_reduce.cu
Normal file
148
csrc/custom_all_reduce.cu
Normal file
@ -0,0 +1,148 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
// fake pointer type
|
||||
using fptr_t = uint64_t;
|
||||
static_assert(sizeof(void *) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
cudaIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor &t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink) {
|
||||
auto inp_size = inp.numel() * inp.element_size();
|
||||
// custom allreduce requires input byte size to be multiples of 16
|
||||
if (inp_size % 16 != 0) return false;
|
||||
if (!_is_weak_contiguous(inp)) return false;
|
||||
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
|
||||
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
|
||||
// performance improvement over NCCL.
|
||||
return false;
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
cudaStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
|
||||
reinterpret_cast<float *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
|
||||
reinterpret_cast<half *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, cudaMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int meta_size() { return sizeof(vllm::Signal); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
return fa->get_graph_buffer_ipc_meta();
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
485
csrc/custom_all_reduce.cuh
Normal file
485
csrc/custom_all_reduce.cuh
Normal file
@ -0,0 +1,485 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
constexpr int kMaxBlocks = 64;
|
||||
// note: we don't want to use atomics for signals because peer atomics are no
|
||||
// supported on PCIe links
|
||||
struct Signal {
|
||||
alignas(128) uint32_t start[kMaxBlocks][8];
|
||||
alignas(128) uint32_t end[kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
|
||||
|
||||
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half &assign_add(half &a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float &assign_add(float &a, float b) { return a += b; }
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// This function is meant to be used as the first synchronization in the all
|
||||
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
|
||||
// prior memory accesses. Note: volatile writes will not be reordered against
|
||||
// other volatile writes.
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
int rank) {
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->end[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->start[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// This function is meant to be used as the second or the final synchronization
|
||||
// barrier in the all reduce kernel. If it's the final synchronization barrier,
|
||||
// we don't need to make any visibility guarantees for prior memory accesses.
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
|
||||
int rank) {
|
||||
__syncthreads();
|
||||
// eliminate the case that prior writes are not visible after signals become
|
||||
// visible. Note that I did not managed to make this happen through a lot of
|
||||
// testing. Might be the case that hardware provides stronger guarantee than
|
||||
// the memory model.
|
||||
if constexpr (!final_sync) __threadfence_system();
|
||||
if (threadIdx.x < ngpus) {
|
||||
// reset flag for next time
|
||||
self_sg->start[blockIdx.x][threadIdx.x] = 0;
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
|
||||
// wait until we got true from all ranks
|
||||
while (!self_sg->end[blockIdx.x][threadIdx.x])
|
||||
;
|
||||
}
|
||||
if constexpr (!final_sync) __syncthreads();
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
|
||||
volatile Signal *self_sg, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P *)result)[idx] =
|
||||
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, self_sg, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P *get_tmp_buf(volatile Signal *sg) {
|
||||
return (P *)(((Signal *)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
|
||||
volatile Signal *self_sg, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
int largest_part = part + size % ngpus;
|
||||
const P *ptrs[ngpus];
|
||||
P *tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P *)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, self_sg, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, self_sg, rank);
|
||||
|
||||
// stage 2: allgather. Note: it's important to match the tid between
|
||||
// the two stages, because visibility across devices is only guaranteed
|
||||
// between threads that have the same tid. If thread i computes the sum of
|
||||
// start + i in the first stage, then thread i also gathers start + i from all
|
||||
// ranks.
|
||||
for (int idx = tid; idx < largest_part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int gather_from_rank = ((rank + i) % ngpus);
|
||||
if (gather_from_rank == ngpus - 1 || idx < part) {
|
||||
int dst_idx = gather_from_rank * part + idx;
|
||||
((P *)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void *, RankData *> buffers_;
|
||||
Signal *self_sg_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void *> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char *> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Signal) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz,
|
||||
const cudaIpcMemHandle_t *handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
self_sg_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Signal *rank_sg;
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_sg = (Signal *)handle;
|
||||
} else {
|
||||
rank_sg = self_sg_;
|
||||
}
|
||||
sg_.signals[i] = rank_sg;
|
||||
}
|
||||
}
|
||||
|
||||
char *open_ipc_handle(const void *ipc_handle) {
|
||||
auto [it, new_handle] =
|
||||
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char *ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t *)ipc_handle),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void *base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " +
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, void *self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(
|
||||
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto &rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char *handle =
|
||||
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
|
||||
sizeof(RankData) * num_buffers,
|
||||
cudaMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T *input, T *output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
if (block_limit > kMaxBlocks)
|
||||
throw std::runtime_error("max supported block limit is " +
|
||||
std::to_string(kMaxBlocks) + ". Got " +
|
||||
std::to_string(block_limit));
|
||||
|
||||
RankData *ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " +
|
||||
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
|
||||
rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void vllm::CustomAllreduce::allreduce<half>(cudaStream_t, half *,
|
||||
half *, int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
316
csrc/custom_all_reduce_test.cu
Normal file
316
csrc/custom_all_reduce_test.cu
Normal file
@ -0,0 +1,316 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=XXX
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
*
|
||||
* To run:
|
||||
* mpirun -np 8 ./custom_all_reduce_test
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#include "mpi.h"
|
||||
#include "nccl.h"
|
||||
|
||||
#define MPICHECK(cmd) \
|
||||
do { \
|
||||
int e = cmd; \
|
||||
if (e != MPI_SUCCESS) { \
|
||||
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
__global__ void dummy_kernel() {
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T *data, int size, int myRank) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
data[idx] = myRank * 0.11f;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||
double *fdata2, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
fdata1[idx] = data1[idx];
|
||||
fdata2[idx] = data2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||
int myRank, int nRanks, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
double sum = 0.0;
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
|
||||
T hval = val; // downcast first
|
||||
sum += static_cast<double>(hval);
|
||||
if (i == myRank) data[idx] = hval;
|
||||
}
|
||||
ground_truth[idx] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
int data_size, bool performance_test) {
|
||||
T *result;
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
|
||||
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
|
||||
|
||||
cudaIpcMemHandle_t self_data_handle;
|
||||
cudaIpcMemHandle_t data_handles[8];
|
||||
vllm::Signal *buffer;
|
||||
T *self_data_copy;
|
||||
/**
|
||||
* Allocate IPC buffer
|
||||
*
|
||||
* The first section is a temporary buffer for storing intermediate allreduce
|
||||
* results, if a particular algorithm requires it. The second section is for
|
||||
* the input to the allreduce. The actual API takes the input pointer as an
|
||||
* argument (that is, they can and usually should be allocated separately).
|
||||
* But since the input pointers and the temporary buffer all require IPC
|
||||
* registration, they are allocated and registered together in the test for
|
||||
* convenience.
|
||||
*/
|
||||
CUDACHECK(
|
||||
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
CUDACHECK(
|
||||
cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
|
||||
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
|
||||
|
||||
MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, MPI_COMM_WORLD));
|
||||
|
||||
void *rank_data;
|
||||
size_t rank_data_sz = 16 * 1024 * 1024;
|
||||
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
|
||||
std::vector<int64_t> offsets(nRanks, 0);
|
||||
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
|
||||
offsets, myRank);
|
||||
auto *self_data =
|
||||
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
|
||||
sizeof(vllm::Signal) + data_size * sizeof(T));
|
||||
// hack buffer registration
|
||||
{
|
||||
std::vector<std::string> handles;
|
||||
handles.reserve(nRanks);
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
char *begin = (char *)&data_handles[i];
|
||||
char *end = (char *)&data_handles[i + 1];
|
||||
handles.emplace_back(begin, end);
|
||||
}
|
||||
std::vector<int64_t> offsets(nRanks,
|
||||
sizeof(vllm::Signal) + data_size * sizeof(T));
|
||||
fa.register_buffer(handles, offsets, self_data);
|
||||
}
|
||||
|
||||
double *ground_truth;
|
||||
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
|
||||
curandState_t *states;
|
||||
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
|
||||
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
|
||||
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
|
||||
nRanks, data_size);
|
||||
CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
cudaEvent_t start, stop;
|
||||
CUDACHECK(cudaEventCreate(&start));
|
||||
CUDACHECK(cudaEventCreate(&stop));
|
||||
|
||||
ncclDataType_t ncclDtype;
|
||||
if (std::is_same<T, half>::value) {
|
||||
ncclDtype = ncclFloat16;
|
||||
} else if (std::is_same<T, nv_bfloat16>::value) {
|
||||
ncclDtype = ncclBfloat16;
|
||||
} else {
|
||||
ncclDtype = ncclFloat;
|
||||
}
|
||||
double *nccl_result, *my_result;
|
||||
CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
|
||||
CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
|
||||
if (performance_test) {
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
constexpr int warmup_iters = 5;
|
||||
constexpr int num_iters = 100;
|
||||
// warmup
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
|
||||
comm, stream));
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum,
|
||||
comm, stream));
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(stop, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
float allreduce_ms = 0;
|
||||
cudaEventElapsedTime(&allreduce_ms, start, stop);
|
||||
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
// warm up
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads,
|
||||
block_limit);
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads,
|
||||
block_limit);
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(stop, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
float duration_ms = 0;
|
||||
cudaEventElapsedTime(&duration_ms, start, stop);
|
||||
if (myRank == 0)
|
||||
printf(
|
||||
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
|
||||
"time:%.2fus\n",
|
||||
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
|
||||
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
|
||||
|
||||
// And wait for all the queued up work to complete
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
|
||||
ncclSum, comm, stream));
|
||||
|
||||
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
|
||||
my_result, data_size);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
for (unsigned long j = 0; j < data_size; j++) {
|
||||
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||
if (diff >= 4e-2) {
|
||||
printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
|
||||
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
long double nccl_diffs = 0.0;
|
||||
long double my_diffs = 0.0;
|
||||
for (int j = 0; j < data_size; j++) {
|
||||
nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||
my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||
}
|
||||
if (myRank == 0)
|
||||
std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||
<< " me: " << my_diffs / data_size << std::endl;
|
||||
} else {
|
||||
for (int i = 0; i < 100; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads,
|
||||
block_limit);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
NCCLCHECK(ncclAllReduce(self_data, self_data_copy, data_size, ncclDtype,
|
||||
ncclSum, comm, stream));
|
||||
convert_data<T><<<108, 1024, 0, stream>>>(
|
||||
self_data_copy, result, nccl_result, my_result, data_size);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
for (unsigned long j = 0; j < data_size; j++) {
|
||||
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||
if (diff >= 4e-2) {
|
||||
printf(
|
||||
"Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
|
||||
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (myRank == 0)
|
||||
printf("Test passed: nGPUs:%d, sz (kb): %d, %d, %d\n", nRanks,
|
||||
data_size * sizeof(T) / 1024, threads, block_limit);
|
||||
// long double nccl_diffs = 0.0;
|
||||
// long double my_diffs = 0.0;
|
||||
// for (int j = 0; j < data_size; j++) {
|
||||
// nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||
// my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||
// }
|
||||
// if (myRank == 0)
|
||||
// std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||
// << " me: " << my_diffs / data_size << std::endl;
|
||||
}
|
||||
|
||||
CUDACHECK(cudaFree(result));
|
||||
CUDACHECK(cudaFree(self_data_copy));
|
||||
CUDACHECK(cudaFree(rank_data));
|
||||
CUDACHECK(cudaFree(buffer));
|
||||
CUDACHECK(cudaFree(states));
|
||||
CUDACHECK(cudaFreeHost(ground_truth));
|
||||
CUDACHECK(cudaFreeHost(nccl_result));
|
||||
CUDACHECK(cudaFreeHost(my_result));
|
||||
CUDACHECK(cudaStreamDestroy(stream));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int nRanks, myRank;
|
||||
MPICHECK(MPI_Init(&argc, &argv));
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
|
||||
CUDACHECK(cudaSetDevice(myRank));
|
||||
ncclUniqueId id;
|
||||
ncclComm_t comm;
|
||||
if (myRank == 0) ncclGetUniqueId(&id);
|
||||
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
|
||||
MPI_COMM_WORLD));
|
||||
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
|
||||
|
||||
bool performance_test = true;
|
||||
cudaProfilerStart();
|
||||
// for (int threads : {256, 512}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
|
||||
// }
|
||||
// }
|
||||
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
|
||||
}
|
||||
|
||||
cudaProfilerStop();
|
||||
return EXIT_SUCCESS;
|
||||
}
|
@ -2,6 +2,8 @@
|
||||
* Adapted from
|
||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
@ -12,3 +14,24 @@
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
||||
|
@ -1,24 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
m.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.cuh"
|
||||
@ -76,6 +77,7 @@ void rms_norm(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
@ -101,6 +103,7 @@ void fused_add_rms_norm(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
|
7
csrc/moe/moe_ops.cpp
Normal file
7
csrc/moe/moe_ops.cpp
Normal file
@ -0,0 +1,7 @@
|
||||
#include "moe_ops.h"
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
|
||||
}
|
9
csrc/moe/moe_ops.h
Normal file
9
csrc/moe/moe_ops.h
Normal file
@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
499
csrc/moe/topk_softmax_kernels.cu
Normal file
499
csrc/moe/topk_softmax_kernels.cu
Normal file
@ -0,0 +1,499 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||
* Copyright (c) 2024, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
/// Number of elements in the array
|
||||
int N,
|
||||
/// Alignment requirement in bytes
|
||||
int Alignment = sizeof(T) * N
|
||||
>
|
||||
class alignas(Alignment) AlignedArray {
|
||||
float data[N];
|
||||
};
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing the output
|
||||
// in the softmax kernel when we extend this module to support expert-choice routing.
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
|
||||
{
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
if ((finished != nullptr) && finished[blockIdx.x])
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
|
||||
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int num_rows = gridDim.x;
|
||||
const int block_row = blockIdx.x;
|
||||
|
||||
const bool row_is_active = finished ? !finished[block_row] : true;
|
||||
const int thread_read_offset = blockIdx.x * num_experts;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
|
||||
{
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
|
||||
{
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert)
|
||||
{
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
// Ignore experts the node isn't responsible for with expert parallelism
|
||||
const int expert = result_kvp.key;
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = result_kvp.value;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
||||
assert(indices[idx] >= 0);
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== TopK softmax things ===============================
|
||||
|
||||
/*
|
||||
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
|
||||
are a small power of 2. This allows us to cleanly share the rows among the threads in
|
||||
a single warp and eliminate communication between warps (so no need to use shared mem).
|
||||
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is intended for when the number of experts is a small power of 2.
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
// Restrictions for previous section.
|
||||
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
|
||||
|
||||
// ===================== From this point, we finally start computing run-time variables. ========================
|
||||
|
||||
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
|
||||
// This, each block processes a chunk of rows. We start by computing the start row for each block.
|
||||
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||
|
||||
// Now, using the base row per thread block, we compute the base row per warp.
|
||||
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||
|
||||
// The threads in a warp are split into sub-groups that will work on a row.
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows)
|
||||
{
|
||||
return;
|
||||
}
|
||||
const bool row_is_active = finished ? !finished[thread_row] : true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
||||
// row it will read.
|
||||
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
|
||||
// this can support all powers of 2 up to 16.
|
||||
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
|
||||
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
|
||||
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
float row_chunk[VPT];
|
||||
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
|
||||
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
|
||||
{
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
|
||||
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
||||
// convert to float afterwards for the exp + sum reduction.
|
||||
float thread_max = row_chunk[0];
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < VPT; ++ii)
|
||||
{
|
||||
thread_max = max(thread_max, row_chunk[ii]);
|
||||
}
|
||||
|
||||
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread max in all the threads have the max within the row.
|
||||
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
|
||||
float row_sum = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
|
||||
row_sum += row_chunk[ii];
|
||||
}
|
||||
|
||||
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
||||
}
|
||||
|
||||
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
||||
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
|
||||
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
|
||||
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
|
||||
// argmax after computing the softmax.
|
||||
const float reciprocal_row_sum = 1.f / row_sum;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||
}
|
||||
|
||||
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
|
||||
// with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
|
||||
{
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index are processed first and only
|
||||
// updated if > (not >=)
|
||||
if (val > max_val)
|
||||
{
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
|
||||
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
|
||||
// then blank out their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this way
|
||||
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||
{
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0)
|
||||
{
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
|
||||
// single) thread per row of the input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
||||
if (k_idx + 1 < k)
|
||||
{
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group)
|
||||
{
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
||||
gating_output, nullptr, topk_weights, topk_indicies, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||
stream);
|
||||
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
const float* gating_output,
|
||||
float* topk_weights,
|
||||
int* topk_indicies,
|
||||
int* token_expert_indices,
|
||||
float* softmax_workspace,
|
||||
const int num_tokens,
|
||||
const int num_experts,
|
||||
const int topk,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
||||
break;
|
||||
default: {
|
||||
TORCH_CHECK(softmax_workspace != nullptr,
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
||||
static constexpr int TPB = 256;
|
||||
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
gating_output, nullptr, softmax_workspace, num_experts);
|
||||
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
|
||||
num_experts, topk, 0, num_experts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights, // [num_tokens, topk]
|
||||
torch::Tensor& topk_indices, // [num_tokens, topk]
|
||||
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
||||
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
}
|
125
csrc/moe_align_block_size_kernels.cu
Normal file
125
csrc/moe_align_block_size_kernels.cu
Normal file
@ -0,0 +1,125 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
|
||||
namespace {
|
||||
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
|
||||
// don't worry about overflow because num_experts is relatively small
|
||||
return row * total_col + col;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
int32_t *sorted_token_ids,
|
||||
int32_t *expert_ids,
|
||||
int32_t *total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
|
||||
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are assigned
|
||||
* to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// We accumulate the token counts of all experts in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding blocks
|
||||
* and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token after
|
||||
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
|
||||
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
|
||||
* where * represents a padding value(preset in python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
|
||||
* stores the indices of the tokens processed by the expert with expert_id within
|
||||
* the current thread's token shard.
|
||||
*/
|
||||
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
|
||||
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
|
||||
AT_CUDA_CHECK(
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem));
|
||||
kernel<<<1, num_experts, shared_mem, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
}
|
159
csrc/ops.h
Normal file
159
csrc/ops.h
Normal file
@ -0,0 +1,159 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets);
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy);
|
||||
|
||||
torch::Tensor marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace,
|
||||
int64_t size_m,
|
||||
int64_t size_n,
|
||||
int64_t size_k);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama,
|
||||
int bit);
|
||||
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
int bit);
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = uint64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink);
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out);
|
||||
void dispose(fptr_t _fa);
|
||||
int meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets);
|
||||
#endif
|
@ -1,16 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
}
|
@ -1,12 +1,14 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
inline __device__ void apply_token_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
@ -19,14 +21,14 @@ inline __device__ void apply_rotary_embedding(
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = __ldg(cos_ptr + x_index);
|
||||
sin = __ldg(sin_ptr + x_index);
|
||||
cos = VLLM_LDG(cos_ptr + x_index);
|
||||
sin = VLLM_LDG(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = __ldg(cos_ptr + x_index / 2);
|
||||
sin = __ldg(sin_ptr + x_index / 2);
|
||||
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
@ -35,6 +37,42 @@ inline __device__ void apply_rotary_embedding(
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* cache_ptr,
|
||||
const int head_size,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int rot_dim,
|
||||
const int token_idx,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride)
|
||||
{
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
@ -42,8 +80,8 @@ __global__ void rotary_embedding_kernel(
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int query_stride,
|
||||
const int key_stride,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
@ -52,27 +90,29 @@ __global__ void rotary_embedding_kernel(
|
||||
int64_t pos = positions[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void batched_rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens]
|
||||
const int rot_dim,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
int64_t pos = positions[token_idx];
|
||||
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
|
||||
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
|
||||
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -88,11 +128,12 @@ void rotary_embedding(
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int query_stride = query.stride(-2);
|
||||
int key_stride = key.stride(-2);
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
@ -125,3 +166,61 @@ void rotary_embedding(
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
Batched version of rotary embedding, pack multiple LoRAs together
|
||||
and process in batched manner.
|
||||
*/
|
||||
void batched_rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox,
|
||||
int rot_dim,
|
||||
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
|
||||
) {
|
||||
int64_t num_tokens = cos_sin_cache_offsets.size(0);
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
if (is_neox) {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
cos_sin_cache_offsets.data_ptr<int64_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
217
csrc/punica/LICENSE
Normal file
217
csrc/punica/LICENSE
Normal file
@ -0,0 +1,217 @@
|
||||
Contains code from https://github.com/punica-ai/punica
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
|
||||
|
||||
Apache-2.0
|
||||
* third_party/nvbench (with LLVM exception)
|
||||
* third_party/flashinfer
|
||||
|
||||
BSD-3-Clause:
|
||||
* third_party/cutlass
|
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
|
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
|
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
|
71
csrc/punica/bgmv/bgmv_config.h
Normal file
71
csrc/punica/bgmv/bgmv_config.h
Normal file
@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale);
|
||||
|
||||
// clang-format off
|
||||
|
||||
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
f(in_T, out_T, W_T, narrow, 1280) \
|
||||
f(in_T, out_T, W_T, narrow, 1536) \
|
||||
f(in_T, out_T, W_T, narrow, 1728) \
|
||||
f(in_T, out_T, W_T, narrow, 1792) \
|
||||
f(in_T, out_T, W_T, narrow, 2048) \
|
||||
f(in_T, out_T, W_T, narrow, 2304) \
|
||||
f(in_T, out_T, W_T, narrow, 2560) \
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 2816) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
f(in_T, out_T, W_T, narrow, 4608) \
|
||||
f(in_T, out_T, W_T, narrow, 5120) \
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 6144) \
|
||||
f(in_T, out_T, W_T, narrow, 6848) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
f(in_T, out_T, W_T, narrow, 8192) \
|
||||
f(in_T, out_T, W_T, narrow, 9216) \
|
||||
f(in_T, out_T, W_T, narrow, 10240) \
|
||||
f(in_T, out_T, W_T, narrow, 11008) \
|
||||
f(in_T, out_T, W_T, narrow, 12288) \
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
f(in_T, out_T, W_T, narrow, 24576) \
|
||||
f(in_T, out_T, W_T, narrow, 27392) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
f(in_T, out_T, W_T, narrow, 32512) \
|
||||
f(in_T, out_T, W_T, narrow, 32768) \
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||
|
||||
// clang-format on
|
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
|
294
csrc/punica/bgmv/bgmv_impl.cuh
Normal file
294
csrc/punica/bgmv/bgmv_impl.cuh
Normal file
@ -0,0 +1,294 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cuda/pipeline>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "vec_dtypes.cuh"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// nthrs = (32, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t num_pipeline_stages = 2;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
auto pipe = cuda::make_pipeline();
|
||||
|
||||
// pipeline load W/X and compute WX;
|
||||
pipe.producer_acquire();
|
||||
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
pipe.producer_commit();
|
||||
size_t copy_idx, compute_idx;
|
||||
float y = 0.f;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
|
||||
++tile_idx) {
|
||||
copy_idx = tile_idx % num_pipeline_stages;
|
||||
// pipeline stage: async copy W fragment
|
||||
pipe.producer_acquire();
|
||||
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
|
||||
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) + tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
}
|
||||
pipe.producer_commit();
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// pipeline stage: compute WX
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] = sum;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
}
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// final pipeline stage
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] =
|
||||
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
|
||||
? sum
|
||||
: 0.f;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
|
||||
// write Y;
|
||||
if (block.thread_rank() == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
|
||||
}
|
||||
}
|
||||
|
||||
// nthrs = (2, 16, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||
typename in_T, typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t tile_idx = blockIdx.x;
|
||||
|
||||
// load X;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
|
||||
|
||||
// load W;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
|
||||
block.thread_rank() * vec_size);
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
|
||||
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += g.shfl_down(sum, offset);
|
||||
}
|
||||
sum = g.shfl(sum, 0);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||
}
|
||||
}
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
constexpr size_t vec_size = 8;
|
||||
constexpr int tz = 4;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if constexpr (feat_in < feat_out) {
|
||||
static_assert(feat_in % vec_size == 0);
|
||||
constexpr int tx = feat_in / vec_size;
|
||||
|
||||
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
|
||||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
|
||||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
|
||||
|
||||
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
|
||||
constexpr int ty = 32 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
|
||||
constexpr int ty = 16 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else {
|
||||
constexpr int ty = 8 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
} else {
|
||||
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||
feat_in % (vec_size * 16) == 0 ||
|
||||
feat_in % (vec_size * 8) == 0);
|
||||
|
||||
if constexpr (feat_in % (vec_size * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
|
||||
vec_size * sizeof(W_T), tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
|
||||
constexpr int tx = 16;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
|
||||
template void bgmv_kernel<feat_in, feat_out>( \
|
||||
out_T * __restrict__ Y, const in_T *__restrict__ X, \
|
||||
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
|
||||
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
|
||||
int64_t num_layers, int64_t layer_idx, float scale);
|
||||
|
||||
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
|
||||
INST_BGMV(wide, narrow, in_T, out_T, W_T)
|
27
csrc/punica/bgmv/generator.py
Normal file
27
csrc/punica/bgmv/generator.py
Normal file
@ -0,0 +1,27 @@
|
||||
DTYPES = ["fp16", "bf16", "fp32"]
|
||||
DTYPE_MAP = {
|
||||
"fp16": "nv_half",
|
||||
"bf16": "nv_bfloat16",
|
||||
"fp32": "float",
|
||||
}
|
||||
|
||||
TEMPLATE = """
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
for input_dtype in DTYPES:
|
||||
for output_dtype in DTYPES:
|
||||
for weight_dtype in DTYPES:
|
||||
if weight_dtype == "fp32":
|
||||
# FP32 weights are not supported.
|
||||
continue
|
||||
kernel_definition = TEMPLATE.format(
|
||||
input_dtype=DTYPE_MAP[input_dtype],
|
||||
output_dtype=DTYPE_MAP[output_dtype],
|
||||
weight_dtype=DTYPE_MAP[weight_dtype])
|
||||
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
|
||||
with open(filename, "w") as f:
|
||||
f.write(kernel_definition)
|
1324
csrc/punica/bgmv/vec_dtypes.cuh
Normal file
1324
csrc/punica/bgmv/vec_dtypes.cuh
Normal file
File diff suppressed because it is too large
Load Diff
565
csrc/punica/punica_ops.cc
Normal file
565
csrc/punica/punica_ops.cc
Normal file
@ -0,0 +1,565 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cstdint>
|
||||
|
||||
#include "bgmv/bgmv_config.h"
|
||||
|
||||
namespace {
|
||||
|
||||
//====== utils ======
|
||||
|
||||
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
const char *a_name, const char *b_name) {
|
||||
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
|
||||
a.dim(), " vs ", b.dim());
|
||||
for (int i = 0; i < a.dim(); ++i) {
|
||||
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
|
||||
".size(", i, ")");
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
return (uint32_t(a) << 16) | uint32_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) \
|
||||
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) \
|
||||
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
//====== bgmv ======
|
||||
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint16_t in_features, uint16_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
switch (pack_u16(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u16(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
break;
|
||||
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
|
||||
#undef CASE
|
||||
#undef CASE_ONESIDE
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, float scale) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t h_in = x.size(1);
|
||||
int64_t h_out = y.size(1);
|
||||
int64_t num_layers = w.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
float scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t num_layers = w.size(1);
|
||||
int64_t full_y_size = y.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//====== pybind ======
|
||||
|
||||
#define DEFINE_pybind(name) m.def(#name, &name, #name);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
|
||||
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
|
||||
"dispatch_bgmv_low_level");
|
||||
}
|
126
csrc/pybind.cpp
Normal file
126
csrc/pybind.cpp
Normal file
@ -0,0 +1,126 @@
|
||||
#include "cache.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// vLLM custom ops
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_and_mul",
|
||||
&gelu_and_mul,
|
||||
"Activation function used in GeGLU with `none` approximation.");
|
||||
ops.def(
|
||||
"gelu_tanh_and_mul",
|
||||
&gelu_tanh_and_mul,
|
||||
"Activation function used in GeGLU with `tanh` approximation.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
ops.def(
|
||||
"batched_rotary_embedding",
|
||||
&batched_rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
#endif
|
||||
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
ops.def(
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
cache_ops.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"convert_fp8_e5m2",
|
||||
&convert_fp8_e5m2,
|
||||
"Convert the key and value cache to fp8_e5m2 data type");
|
||||
|
||||
// Cuda utils
|
||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Custom all-reduce kernels
|
||||
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
||||
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
||||
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
||||
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
||||
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
||||
custom_ar.def("dispose", &dispose, "dispose");
|
||||
custom_ar.def("meta_size", &meta_size, "meta_size");
|
||||
custom_ar.def("register_buffer", ®ister_buffer, "register_buffer");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
||||
"get_graph_buffer_ipc_meta");
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
@ -1,19 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
}
|
@ -27,35 +27,48 @@ __pack_half2(const half x, const half y) {
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
template<int N>
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
int G,
|
||||
int split_k_iters,
|
||||
half* __restrict__ A,
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
int M,
|
||||
int IC,
|
||||
int OC,
|
||||
half* __restrict__ C)
|
||||
{
|
||||
// Only support matrix n = 64 or 128
|
||||
assert(N == 64 || N == 128);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (128 + 8)];
|
||||
__shared__ half B_shared[32 * (N + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[128];
|
||||
__shared__ half zeros_shared[128];
|
||||
__shared__ half scaling_factors_shared[N];
|
||||
__shared__ half zeros_shared[N];
|
||||
|
||||
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||
int j_factors1 = ((OC + N - 1) / N);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[32];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||
half B_shared_warp[N / 4];
|
||||
for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
@ -65,10 +78,10 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
@ -77,22 +90,22 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ ((int)threadIdx.x) % (128 / 8);
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ ((int)threadIdx.x) % (N / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||
+ ((int)threadIdx.y) * 64
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ ((int)threadIdx.y) * (N / 2)
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
@ -123,7 +136,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
@ -152,7 +165,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@ -174,13 +187,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
);
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
@ -190,7 +203,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
);
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
@ -258,118 +271,45 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
__global__ void __launch_bounds__(64) dequantize_weights(
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
half* __restrict__ C,
|
||||
int G
|
||||
)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (64 + 8)];
|
||||
half B_shared[32 * (128 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[64];
|
||||
__shared__ half zeros_shared[64];
|
||||
half* B_shared_ptr2 = B_shared;
|
||||
|
||||
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||
half B_shared_warp[32];
|
||||
int OC = 512;
|
||||
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
int N = blockDim.x * gridDim.x; // 2
|
||||
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int index1 = 8 * col + 8 * row * N;
|
||||
half* C_ptr2 = C + index1;
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[16];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
int index2 = col + row * N;
|
||||
int* B_ptr2 = B + index2;
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
int index3 = col + (int)(row / G) * N;
|
||||
int* zeros_ptr2 = zeros + index3;
|
||||
int index4 = 8 * col + (int)(row / G) * N * 8;
|
||||
half* scaling_factors_ptr2 = scaling_factors + index4;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ ((int)threadIdx.x) % (64 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||
+ ((int)threadIdx.y) * 32
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
@ -378,124 +318,68 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
*(C_ptr2 + i) = B_shared[i];
|
||||
}
|
||||
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy)
|
||||
{
|
||||
int in_c = _kernel.size(0);
|
||||
int qout_c = _kernel.size(1);
|
||||
int out_c = qout_c * 8;
|
||||
int G = in_c / _scaling_factors.size(0);
|
||||
|
||||
int x_thread = thx;
|
||||
int y_thread = thy;
|
||||
|
||||
int x_blocks = 1;
|
||||
int y_blocks = 1;
|
||||
if (thx==0) {
|
||||
x_thread = qout_c;
|
||||
}
|
||||
if (thy==0) {
|
||||
y_thread = in_c;
|
||||
}
|
||||
if (thx==0 && thy==0) {
|
||||
x_thread = 8;
|
||||
y_thread = 8;
|
||||
x_blocks = (int)(qout_c / 8);
|
||||
y_blocks = (int)(in_c / 8);
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
|
||||
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
|
||||
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
|
||||
dim3 num_blocks(x_blocks, y_blocks);
|
||||
dim3 threads_per_block(x_thread, y_thread);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
kernel, scaling_factors, zeros, de_kernel, G);
|
||||
|
||||
return _de_kernel;
|
||||
}
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
@ -542,8 +426,9 @@ torch::Tensor awq_gemm(
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
@ -553,8 +438,9 @@ torch::Tensor awq_gemm(
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
|
||||
|
277
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
277
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
@ -0,0 +1,277 @@
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
#include "../../attention/attention_dtypes.h"
|
||||
#include "../../attention/dtype_float32.cuh"
|
||||
#include "../../attention/dtype_float16.cuh"
|
||||
#include "../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
namespace fp8_e5m2_unscaled {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template<>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template<>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template<>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template<>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// half -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
} // namespace fp8_e5m2_unscaled
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
} // namespace vllm
|
64
csrc/quantization/gptq/compat.cuh
Normal file
64
csrc/quantization/gptq/compat.cuh
Normal file
@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user