Compare commits
530 Commits
Author | SHA1 | Date | |
---|---|---|---|
5f08050d8d | |||
64da65b322 | |||
5255d99dc5 | |||
4f2ad11135 | |||
d7afab6d3a | |||
31348dff03 | |||
25e86b6a61 | |||
4efbac6d35 | |||
87069ccf68 | |||
7e45107f51 | |||
0c48b37c31 | |||
7eacffd951 | |||
2a543d6efe | |||
317b29de0f | |||
a463c333dd | |||
ea356004d4 | |||
5c976a7e1a | |||
f964493274 | |||
a4211a4dc3 | |||
563836496a | |||
4ca2c358b1 | |||
0580aab02f | |||
3711811b1d | |||
65b89d16ee | |||
931746bc6d | |||
c81dddb45c | |||
fe6d09ae61 | |||
ed70c70ea3 | |||
f0d4e14557 | |||
2ccee3def6 | |||
b92adec8e8 | |||
56f738ae9b | |||
72d3a30c63 | |||
c9b45adeeb | |||
5a6c81b051 | |||
51cd22ce56 | |||
5ed704ec8c | |||
4abf6336ec | |||
0e163fce18 | |||
96b6f475dd | |||
c410f5d020 | |||
bb8c697ee0 | |||
b9e96b17de | |||
923797fea4 | |||
cd9e60c76c | |||
93b38bea5d | |||
d0d93b92b1 | |||
89efcf1ce5 | |||
c664b0e683 | |||
d69ff0cbbb | |||
1af090b57d | |||
3dad944485 | |||
105a40f53a | |||
bbe9bd9684 | |||
4f65af0e25 | |||
d79ced3292 | |||
ab40644669 | |||
5d60def02c | |||
ea8489fce2 | |||
1b20639a43 | |||
b72af8f1ed | |||
9090bf02e7 | |||
7d648418b8 | |||
89be30fa7d | |||
f8ecb84c02 | |||
5f036d2bcc | |||
380170038e | |||
220a47627b | |||
beb89f68b4 | |||
390b495ff3 | |||
3a0e1fc070 | |||
6b7de1a030 | |||
5265631d15 | |||
2832e7b9f9 | |||
3a7dd7e367 | |||
223c19224b | |||
f1f6cc10c7 | |||
3209b49033 | |||
1e4277d2d1 | |||
9b945daaf1 | |||
9c1352eb57 | |||
7a0b011dd5 | |||
63e835cbcc | |||
94b5edeb53 | |||
ab7e6006d6 | |||
18bfcdd05c | |||
71d63ed72e | |||
d75c40734a | |||
5b23c3f26f | |||
00efdc84ba | |||
91a61da9b1 | |||
ef9b636e2d | |||
2709c0009a | |||
dd7e8f5f64 | |||
d2a68364c4 | |||
7e1081139d | |||
18473cf498 | |||
4df417d059 | |||
5d80a9178b | |||
8a25d3a71a | |||
d10f8e1d43 | |||
14cc317ba4 | |||
e1957c6ebd | |||
8cd5a992bf | |||
947f0b23cc | |||
f780504d12 | |||
bfc072addf | |||
2a18da257c | |||
6e01e8c1c8 | |||
9f659bf07f | |||
35c4bc20d9 | |||
218dc2ccda | |||
827cbcd37c | |||
cb7a1c1cbf | |||
7878958c0d | |||
ce036244c9 | |||
48cf1e413c | |||
97460585d9 | |||
f745847ef7 | |||
6549aef245 | |||
50376faa7b | |||
4b61c6b669 | |||
79d64c4954 | |||
74cd5abdd1 | |||
28c3f12104 | |||
c884819135 | |||
05921a9a7a | |||
d0215a58e7 | |||
937e7b7d7c | |||
aee8ef661a | |||
2e0b6e7757 | |||
941767127c | |||
74d8d77626 | |||
fd4ea8ef5c | |||
1066cbd152 | |||
6ef00b03a2 | |||
9140561059 | |||
77af974b40 | |||
4934d49274 | |||
358c328d69 | |||
4aaafdd289 | |||
66b108d142 | |||
e0ff920001 | |||
face83c7ec | |||
1db83e31a2 | |||
a1b9cb2a34 | |||
3a4fd5ca59 | |||
c17daa9f89 | |||
bd29cf3d3a | |||
31bff69151 | |||
ba4f826738 | |||
de60a3fb93 | |||
21d5daa4ac | |||
290e015c6c | |||
1b7c791d60 | |||
bbe4466fd9 | |||
08133c4d1a | |||
76a7983b23 | |||
8041b7305e | |||
3ec8c25cd0 | |||
671af2b1c0 | |||
6f41f0e377 | |||
2c9b638065 | |||
a7347d9a6d | |||
f8c688d746 | |||
c9fadda543 | |||
30fb0956df | |||
3a765bd5e1 | |||
26c52a5ea6 | |||
c3372e87be | |||
b0a1d667b0 | |||
e1d5402238 | |||
3d1cfbfc74 | |||
37ca558103 | |||
eed74a558f | |||
2acd76f346 | |||
b81a6a6bb3 | |||
0fbfc4b81b | |||
c06170cc8e | |||
614856da25 | |||
05bdf4eaf3 | |||
6774bd50b0 | |||
31c1f3255e | |||
21d93c140d | |||
f1c8520146 | |||
096827c284 | |||
6565d9e33e | |||
f375ec8440 | |||
518369d78c | |||
30bad5c492 | |||
3fefe271ec | |||
6428f1d051 | |||
7e1b21daac | |||
cb3f30c600 | |||
f3e024bece | |||
31d2ab4aff | |||
eb17212858 | |||
4dd4b5c538 | |||
6120e5aaea | |||
2eaa81b236 | |||
81ce2a4b26 | |||
5dd80d3777 | |||
beeee69bc9 | |||
9bf28d0b69 | |||
c0ce15dfb2 | |||
b9bcdc7158 | |||
4ff0203987 | |||
b5f882cc98 | |||
2e8fc0d4c3 | |||
dacaf5a400 | |||
24cde76a15 | |||
1aa1361510 | |||
fe470ae5ad | |||
3a8c2381f7 | |||
c85b80c2b6 | |||
2b981012a6 | |||
6ccc0bfffb | |||
c8e7eb1eb3 | |||
24f60a54f4 | |||
42c02f5892 | |||
ebede26ebf | |||
d940ce497e | |||
05ff90b692 | |||
1d9b737e05 | |||
60dc62dc9e | |||
0f90effc66 | |||
464dd985e3 | |||
c07a442854 | |||
cd3aa153a4 | |||
9b294976a2 | |||
5313c2cb8b | |||
5f09cbdb63 | |||
4cefa9b49b | |||
f86bd6190a | |||
e5452ddfd6 | |||
d06980dfa7 | |||
66785cc05c | |||
05a38612b0 | |||
d27f4bae39 | |||
8d8c2f6ffe | |||
51d3cb951d | |||
e74b1736a1 | |||
f07c1ceaa5 | |||
63b2206ad0 | |||
27feead2f8 | |||
c782195662 | |||
0f621c2c7d | |||
a9e4574261 | |||
0229c386c5 | |||
a7b3e33078 | |||
e19a64c7ef | |||
1cb4ad8de9 | |||
6ed068a71a | |||
708e6c18b0 | |||
b943890484 | |||
a1125ad4df | |||
a8b150c595 | |||
665cbcec4b | |||
7c600440f7 | |||
e0c6f556e8 | |||
de23687d16 | |||
4cea74c73b | |||
a921d8be9d | |||
094f716bf2 | |||
7d761fe3c1 | |||
cf35d8f3d7 | |||
4bb6b67188 | |||
819b18e7ba | |||
19849db573 | |||
3d4ceb292c | |||
f5a37c6c6c | |||
32c927b53f | |||
5ffc0d13a2 | |||
112627e8b2 | |||
37c1e3c218 | |||
06e9ebebd5 | |||
c5f7740d89 | |||
be66d9b125 | |||
e1054247ba | |||
8d17774f92 | |||
e946260cf3 | |||
edb305584b | |||
bb00f66e19 | |||
e87557b069 | |||
dcc543a298 | |||
0fc280b06c | |||
20d0699d49 | |||
686f5e3210 | |||
415d109527 | |||
521b35f799 | |||
cb08cd0d75 | |||
2a2c135b41 | |||
65ea2ddf17 | |||
b514d3c496 | |||
7076fa1c9f | |||
660a7fcfa4 | |||
054072bee5 | |||
eb825c1e74 | |||
1b290ace4f | |||
0d578228ca | |||
aebfcb262a | |||
ab9e8488d5 | |||
fd58b73a40 | |||
8efe23f150 | |||
06458a0b42 | |||
1a2bbc9301 | |||
e7f579eb97 | |||
8516999495 | |||
9f669a9a7c | |||
555bdcc5a3 | |||
54ca1ba71d | |||
9738b84a08 | |||
1fe0990023 | |||
7e90a2d117 | |||
5687d584fe | |||
cf8849f2d6 | |||
e575df33b1 | |||
0ce8647dc5 | |||
9cabcb7645 | |||
7b895c5976 | |||
7013a80170 | |||
79a30912b8 | |||
2f3d36a8a1 | |||
ac8d36f3e5 | |||
15f5632365 | |||
aa9af07cac | |||
69be658bba | |||
beac8dd461 | |||
28b47d1e49 | |||
1f24755bf8 | |||
bf31d3606a | |||
d189170b6c | |||
f61dc8072f | |||
f8a1e39fae | |||
a132435204 | |||
9524867701 | |||
c1376e0f82 | |||
651c614aa4 | |||
d3a5bd9fb7 | |||
e8ef4c0820 | |||
348897af31 | |||
9d9072a069 | |||
928de46888 | |||
29678cd213 | |||
d0740dff1b | |||
de89472897 | |||
e7c8555d06 | |||
ec3b5ce9cc | |||
6368e777a8 | |||
875afe38ab | |||
ee8217e5be | |||
980dd4a2c4 | |||
8285736840 | |||
91fce82c6f | |||
ac5cf86aa6 | |||
6a6119554c | |||
b95ee898fe | |||
9eed4d1f3e | |||
6b5296aa3a | |||
ee92b58b3a | |||
09ff7f106a | |||
acbed3ef40 | |||
66d18a7fb0 | |||
ba0bfd40e2 | |||
84e4e37d14 | |||
a60b353005 | |||
ebe4d1db3a | |||
b5a10eb0ef | |||
0967102c6d | |||
e2fb71ec9f | |||
f936657eb6 | |||
6f88f762bf | |||
202351d5bf | |||
2e8e49fce3 | |||
a8e98aee0c | |||
bb1ba58f06 | |||
7bedab5748 | |||
20f7cc4cde | |||
649aa730c5 | |||
a19bc5c628 | |||
28e616c4e3 | |||
30e775281d | |||
21877b0d75 | |||
cf5cb1e33e | |||
03ffd0a022 | |||
a425bd9a9a | |||
bbbf86565f | |||
9f6be8692e | |||
f187877945 | |||
947b794146 | |||
8d926e91f1 | |||
4ee52bb169 | |||
7d7e3b78a3 | |||
f98b745a81 | |||
2d1e86f1b1 | |||
1ac4ccf73c | |||
2ac4d5e2bf | |||
3302f0aef3 | |||
6f2dd6c37e | |||
bc0644574c | |||
400b8289f7 | |||
c1026311b5 | |||
2b1c116b5a | |||
cc796b1358 | |||
f029ef94d7 | |||
95592fa00a | |||
fbe66e1d0b | |||
90979c38f8 | |||
e21d7687a9 | |||
ff36139ffc | |||
e3e79e9e8a | |||
b9fe4616f9 | |||
64ca424e75 | |||
b5f93d0631 | |||
a58936966f | |||
dd54a4b026 | |||
eda1a7cad3 | |||
f04908cae7 | |||
ab019eea75 | |||
9841d48a10 | |||
3272d7a0b7 | |||
0bb1e885a0 | |||
d6545ad22e | |||
90eb3f43ca | |||
e67b4f2c2a | |||
d6770d1f23 | |||
b9cecc2635 | |||
898285c9bf | |||
a62de9ecfd | |||
4042d192f5 | |||
1117aa1411 | |||
080438477f | |||
4b5bcf8906 | |||
852ef5b4f5 | |||
db09d4ad83 | |||
c957c741d9 | |||
c07ece5ca4 | |||
7a9c20c715 | |||
005ba458b5 | |||
320a622ec4 | |||
c9927c1a6a | |||
fbd80ad409 | |||
22379d5513 | |||
1696725879 | |||
002800f081 | |||
e15932bb60 | |||
ce741ba3e4 | |||
bf87484efa | |||
8ce9c50d40 | |||
32b6816e55 | |||
c128d69856 | |||
55b28b1eee | |||
e11222333f | |||
28873a2799 | |||
0080d8329d | |||
0d93f15694 | |||
becd7a56f1 | |||
75471386de | |||
d2b2eed67c | |||
4b6f069b6f | |||
791d79de32 | |||
94d2f59895 | |||
75c0ca9d43 | |||
2a4ec90854 | |||
85ebcda94d | |||
d64bf1646c | |||
a41c20435e | |||
eedac9dba0 | |||
14f9c72bfd | |||
ad5f2fe34c | |||
4f8584756d | |||
65fc1c3127 | |||
c393af6cd7 | |||
0c04ce3234 | |||
73b3de79ea | |||
d1744376ae | |||
805de738f6 | |||
1b151ed181 | |||
e06f504a76 | |||
462ae5220a | |||
66c54aa9c3 | |||
735ecfff61 | |||
a57d13cc96 | |||
79af7e96a0 | |||
621980bdc0 | |||
aa84c92ef6 | |||
f7389f4763 | |||
55fe8a81ec | |||
e8ddc08ec8 | |||
1b0bd0fe8a | |||
20044cab7a | |||
64f23c2900 | |||
d4c7755ca8 | |||
aa39e42c5a | |||
953f28cf9a | |||
c0d00f5be6 | |||
58a072be15 | |||
82ad323dee | |||
df5dd3c68e | |||
2d867b55fa | |||
d7a1c6d614 | |||
7d5a155e4a | |||
1dde34e0f8 | |||
6fc2a38b11 | |||
c487a221ee | |||
9925c17940 | |||
8c4b2592fb | |||
cf21a9bd5c | |||
16c3e295a8 | |||
bda41c70dd | |||
453bafb96f | |||
328d231c17 | |||
b4b195b360 | |||
20b0d88d16 | |||
2bdea7ac11 | |||
58df2883cb | |||
6d7d95a70a | |||
96853af5a8 | |||
dbed69058c | |||
7b6ae94059 | |||
c6dfc3cdbe | |||
51be365143 | |||
c894836108 | |||
75beba29b5 | |||
ddfdf470ae | |||
b6fbb9a565 | |||
2179e4f4c5 | |||
a945fcc2ae | |||
be54f8e5c4 | |||
b396cb4998 |
69
.buildkite/run-benchmarks.sh
Normal file
@ -0,0 +1,69 @@
|
||||
# This script is run by buildkite to run the benchmarks and upload the results to buildkite
|
||||
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
# cd into parent directory of this file
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")/.."
|
||||
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# run python-based benchmarks and upload the result to buildkite
|
||||
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
|
||||
bench_latency_exit_code=$?
|
||||
|
||||
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
|
||||
bench_throughput_exit_code=$?
|
||||
|
||||
# run server-based benchmarks and upload the result to buildkite
|
||||
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
|
||||
server_pid=$!
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
# wait for server to start, timeout after 600 seconds
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend openai \
|
||||
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer meta-llama/Llama-2-7b-chat-hf \
|
||||
--save-result \
|
||||
2>&1 | tee benchmark_serving.txt
|
||||
bench_serving_exit_code=$?
|
||||
kill $server_pid
|
||||
|
||||
# write the results into a markdown file
|
||||
echo "### Latency Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line
|
||||
|
||||
echo "### Throughput Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line
|
||||
|
||||
echo "### Serving Benchmarks" >> benchmark_results.md
|
||||
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
|
||||
echo "" >> benchmark_results.md
|
||||
tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines
|
||||
|
||||
# upload the results to buildkite
|
||||
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
|
||||
|
||||
# exit with the exit code of the benchmarks
|
||||
if [ $bench_latency_exit_code -ne 0 ]; then
|
||||
exit $bench_latency_exit_code
|
||||
fi
|
||||
|
||||
if [ $bench_throughput_exit_code -ne 0 ]; then
|
||||
exit $bench_throughput_exit_code
|
||||
fi
|
||||
|
||||
if [ $bench_serving_exit_code -ne 0 ]; then
|
||||
exit $bench_serving_exit_code
|
||||
fi
|
||||
|
||||
/workspace/buildkite-agent artifact upload openai-*.json
|
58
.buildkite/test-pipeline.yaml
Normal file
@ -0,0 +1,58 @@
|
||||
# In this file, you can add more tests to run either by adding a new step or
|
||||
# adding a new command to an existing step. See different options here for examples.
|
||||
# This script will be feed into Jinja template in `test-template.j2` to generate
|
||||
# the final pipeline yaml file.
|
||||
|
||||
steps:
|
||||
- label: Regression Test
|
||||
command: pytest -v -s test_regression.py
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: AsyncEngine Test
|
||||
command: pytest -v -s async_engine
|
||||
|
||||
- label: Distributed Test
|
||||
command: pytest -v -s test_comm_ops.py
|
||||
working_dir: "/vllm-workspace/tests/distributed"
|
||||
num_gpus: 2 # only support 1 or 2 for now.
|
||||
|
||||
- label: Engine Test
|
||||
command: pytest -v -s engine
|
||||
|
||||
- label: Entrypoints Test
|
||||
command: pytest -v -s entrypoints
|
||||
|
||||
- label: Kernels Test
|
||||
command: pytest -v -s kernels
|
||||
soft_fail: true
|
||||
|
||||
- label: Models Test
|
||||
commands:
|
||||
- pytest -v -s models --forked
|
||||
soft_fail: true
|
||||
|
||||
- label: Prefix Caching Test
|
||||
commands:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
- label: Samplers Test
|
||||
command: pytest -v -s samplers --forked
|
||||
|
||||
- label: Worker Test
|
||||
command: pytest -v -s worker
|
||||
|
||||
- label: LoRA Test
|
||||
command: pytest -v -s lora
|
||||
|
||||
- label: Benchmarks
|
||||
working_dir: "/vllm-workspace/.buildkite"
|
||||
commands:
|
||||
- pip install aiohttp
|
||||
- bash run-benchmarks.sh
|
||||
|
||||
- label: Documentation Build
|
||||
working_dir: "/vllm-workspace/docs"
|
||||
no_gpu: True
|
||||
commands:
|
||||
- pip install -r requirements-docs.txt
|
||||
- SPHINXOPTS=\"-W\" make html
|
56
.buildkite/test-template.j2
Normal file
@ -0,0 +1,56 @@
|
||||
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
|
||||
{% set default_num_gpu = 1 %}
|
||||
{% set default_working_dir = "/vllm-workspace/tests" %}
|
||||
|
||||
steps:
|
||||
- label: ":docker: build image"
|
||||
commands:
|
||||
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
|
||||
- "docker push {{ docker_image }}"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
- wait
|
||||
|
||||
{% for step in steps %}
|
||||
- label: "{{ step.label }}"
|
||||
agents:
|
||||
queue: kubernetes
|
||||
soft_fail: {{ step.soft_fail or false }}
|
||||
retry:
|
||||
automatic:
|
||||
- exit_status: -1 # Agent was lost
|
||||
limit: 5
|
||||
plugins:
|
||||
- kubernetes:
|
||||
podSpec:
|
||||
volumes:
|
||||
- name: dshm
|
||||
emptyDir:
|
||||
medium: Memory
|
||||
containers:
|
||||
- image: "{{ docker_image }}"
|
||||
command: ["bash"]
|
||||
args:
|
||||
- '-c'
|
||||
- "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
|
||||
{% if not step.no_gpu %}
|
||||
resources:
|
||||
requests:
|
||||
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
|
||||
limits:
|
||||
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
|
||||
{% endif %}
|
||||
env:
|
||||
- name: HF_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: hf-token-secret
|
||||
key: token
|
||||
volumeMounts:
|
||||
- mountPath: /dev/shm
|
||||
name: dshm
|
||||
{% endfor %}
|
1
.dockerignore
Normal file
@ -0,0 +1 @@
|
||||
vllm/*.so
|
102
.github/workflows/publish.yml
vendored
Normal file
@ -0,0 +1,102 @@
|
||||
# This workflow will upload a Python Package to Release asset
|
||||
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
|
||||
|
||||
name: Create Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
# Needed to create release and upload assets
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
# Retrieve tag and create release
|
||||
name: Create Release
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
upload_url: ${{ steps.create_release.outputs.upload_url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Extract branch info
|
||||
shell: bash
|
||||
run: |
|
||||
echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
|
||||
|
||||
- name: Create Release
|
||||
id: create_release
|
||||
uses: "actions/github-script@v6"
|
||||
env:
|
||||
RELEASE_TAG: ${{ env.release_tag }}
|
||||
with:
|
||||
github-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
script: |
|
||||
const script = require('.github/workflows/scripts/create_release.js')
|
||||
await script(github, context, core)
|
||||
|
||||
wheel:
|
||||
name: Build Wheel
|
||||
runs-on: ${{ matrix.os }}
|
||||
needs: release
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
|
||||
cuda-version: ['11.8', '12.1']
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Linux Env
|
||||
if: ${{ runner.os == 'Linux' }}
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/env.sh
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
|
||||
|
||||
- name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
|
||||
|
||||
- name: Build wheel
|
||||
shell: bash
|
||||
run: |
|
||||
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
|
||||
wheel_name=$(ls dist/*whl | xargs -n 1 basename)
|
||||
asset_name=${wheel_name//"linux"/"manylinux1"}
|
||||
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
|
||||
echo "asset_name=${asset_name}" >> $GITHUB_ENV
|
||||
|
||||
- name: Upload Release Asset
|
||||
uses: actions/upload-release-asset@v1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ needs.release.outputs.upload_url }}
|
||||
asset_path: ./dist/${{ env.wheel_name }}
|
||||
asset_name: ${{ env.asset_name }}
|
||||
asset_content_type: application/*
|
||||
|
||||
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
|
||||
# - name: Publish package
|
||||
# uses: pypa/gh-action-pypi-publish@release/v1.8
|
||||
# with:
|
||||
# repository-url: https://test.pypi.org/legacy/
|
||||
# password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
# skip-existing: true
|
@ -1,4 +1,4 @@
|
||||
name: pylint
|
||||
name: ruff
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
@ -11,7 +11,7 @@ on:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
pylint:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
@ -25,7 +25,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pylint==2.8.2
|
||||
- name: Analysing the code with pylint
|
||||
pip install ruff==0.1.5
|
||||
- name: Analysing the code with ruff
|
||||
run: |
|
||||
pylint vllm
|
||||
ruff vllm tests
|
20
.github/workflows/scripts/build.sh
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
python_executable=python$1
|
||||
cuda_home=/usr/local/cuda-$2
|
||||
|
||||
# Update paths
|
||||
PATH=${cuda_home}/bin:$PATH
|
||||
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Install requirements
|
||||
$python_executable -m pip install wheel packaging
|
||||
$python_executable -m pip install -r requirements.txt
|
||||
|
||||
# Limit the number of parallel jobs to avoid OOM
|
||||
export MAX_JOBS=1
|
||||
# Make sure punica is built for the release (for LoRA)
|
||||
export VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
# Build
|
||||
$python_executable setup.py bdist_wheel --dist-dir=dist
|
20
.github/workflows/scripts/create_release.js
vendored
Normal file
@ -0,0 +1,20 @@
|
||||
// Uses Github's API to create the release and wait for result.
|
||||
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
|
||||
|
||||
module.exports = async (github, context, core) => {
|
||||
try {
|
||||
const response = await github.rest.repos.createRelease({
|
||||
draft: false,
|
||||
generate_release_notes: true,
|
||||
name: process.env.RELEASE_TAG,
|
||||
owner: context.repo.owner,
|
||||
prerelease: false,
|
||||
repo: context.repo.repo,
|
||||
tag_name: process.env.RELEASE_TAG,
|
||||
});
|
||||
|
||||
core.setOutput('upload_url', response.data.upload_url);
|
||||
} catch (error) {
|
||||
core.setFailed(error.message);
|
||||
}
|
||||
}
|
23
.github/workflows/scripts/cuda-install.sh
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Replace '.' with '-' ex: 11.8 -> 11-8
|
||||
cuda_version=$(echo $1 | tr "." "-")
|
||||
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
|
||||
OS=$(echo $2 | tr -d ".\-")
|
||||
|
||||
# Installs CUDA
|
||||
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
|
||||
sudo dpkg -i cuda-keyring_1.1-1_all.deb
|
||||
rm cuda-keyring_1.1-1_all.deb
|
||||
sudo apt -qq update
|
||||
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
|
||||
sudo apt clean
|
||||
|
||||
# Test nvcc
|
||||
PATH=/usr/local/cuda-$1/bin:${PATH}
|
||||
nvcc --version
|
||||
|
||||
# Log gcc, g++, c++ versions
|
||||
gcc --version
|
||||
g++ --version
|
||||
c++ --version
|
56
.github/workflows/scripts/env.sh
vendored
Normal file
@ -0,0 +1,56 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This file installs common linux environment tools
|
||||
|
||||
export LANG C.UTF-8
|
||||
|
||||
# python_version=$1
|
||||
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
software-properties-common \
|
||||
|
||||
sudo apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
apt-utils \
|
||||
ca-certificates \
|
||||
wget \
|
||||
git \
|
||||
vim \
|
||||
libssl-dev \
|
||||
curl \
|
||||
unzip \
|
||||
unrar \
|
||||
cmake \
|
||||
net-tools \
|
||||
sudo \
|
||||
autotools-dev \
|
||||
rsync \
|
||||
jq \
|
||||
openssh-server \
|
||||
tmux \
|
||||
screen \
|
||||
htop \
|
||||
pdsh \
|
||||
openssh-client \
|
||||
lshw \
|
||||
dmidecode \
|
||||
util-linux \
|
||||
automake \
|
||||
autoconf \
|
||||
libtool \
|
||||
net-tools \
|
||||
pciutils \
|
||||
libpci-dev \
|
||||
libaio-dev \
|
||||
libcap2 \
|
||||
libtinfo5 \
|
||||
fakeroot \
|
||||
devscripts \
|
||||
debhelper \
|
||||
nfs-common
|
||||
|
||||
# Remove github bloat files to free up disk space
|
||||
sudo rm -rf "/usr/local/share/boost"
|
||||
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
sudo rm -rf "/usr/share/dotnet"
|
15
.github/workflows/scripts/pytorch-install.sh
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
python_executable=python$1
|
||||
pytorch_version=$2
|
||||
cuda_version=$3
|
||||
|
||||
# Install torch
|
||||
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
|
||||
$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --extra-index-url https://download.pytorch.org/whl/cu${cuda_version//./}
|
||||
|
||||
# Print version information
|
||||
$python_executable --version
|
||||
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
|
||||
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
|
||||
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
|
2
.github/workflows/yapf.yml
vendored
@ -28,4 +28,4 @@ jobs:
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
|
||||
yapf --diff --recursive .
|
||||
|
11
.gitignore
vendored
@ -173,3 +173,14 @@ cython_debug/
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
|
||||
# Benchmark dataset
|
||||
*.json
|
||||
|
434
.pylintrc
@ -1,434 +0,0 @@
|
||||
# This Pylint rcfile contains a best-effort configuration to uphold the
|
||||
# best-practices and style described in the Google Python style guide:
|
||||
# https://google.github.io/styleguide/pyguide.html
|
||||
#
|
||||
# Its canonical open-source location is:
|
||||
# https://google.github.io/styleguide/pylintrc
|
||||
|
||||
[MASTER]
|
||||
|
||||
# Files or directories to be skipped. They should be base names, not paths.
|
||||
ignore=docs,parallel_utils
|
||||
|
||||
# Files or directories matching the regex patterns are skipped. The regex
|
||||
# matches against base names, not paths.
|
||||
ignore-patterns=
|
||||
|
||||
# Pickle collected data for later comparisons.
|
||||
persistent=no
|
||||
|
||||
# List of plugins (as comma separated values of python modules names) to load,
|
||||
# usually to register additional checkers.
|
||||
load-plugins=
|
||||
|
||||
# Use multiple processes to speed up Pylint.
|
||||
jobs=4
|
||||
|
||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||
# active Python interpreter and may run arbitrary code.
|
||||
unsafe-load-any-extension=no
|
||||
|
||||
|
||||
[MESSAGES CONTROL]
|
||||
|
||||
# Only show warnings with the listed confidence levels. Leave empty to show
|
||||
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
|
||||
confidence=
|
||||
|
||||
# Enable the message, report, category or checker with the given id(s). You can
|
||||
# either give multiple identifier separated by comma (,) or put this option
|
||||
# multiple time (only on the command line, not in the configuration file where
|
||||
# it should appear only once). See also the "--disable" option for examples.
|
||||
#enable=
|
||||
|
||||
# Disable the message, report, category or checker with the given id(s). You
|
||||
# can either give multiple identifiers separated by comma (,) or put this
|
||||
# option multiple times (only on the command line, not in the configuration
|
||||
# file where it should appear only once).You can also use "--disable=all" to
|
||||
# disable everything first and then reenable specific checks. For example, if
|
||||
# you want to run only the similarities checker, you can use "--disable=all
|
||||
# --enable=similarities". If you want to run only the classes checker, but have
|
||||
# no Warning level messages displayed, use"--disable=all --enable=classes
|
||||
# --disable=W"
|
||||
disable=abstract-method,
|
||||
apply-builtin,
|
||||
arguments-differ,
|
||||
attribute-defined-outside-init,
|
||||
backtick,
|
||||
bad-option-value,
|
||||
basestring-builtin,
|
||||
buffer-builtin,
|
||||
c-extension-no-member,
|
||||
consider-using-enumerate,
|
||||
cmp-builtin,
|
||||
cmp-method,
|
||||
coerce-builtin,
|
||||
coerce-method,
|
||||
delslice-method,
|
||||
div-method,
|
||||
duplicate-code,
|
||||
eq-without-hash,
|
||||
execfile-builtin,
|
||||
file-builtin,
|
||||
filter-builtin-not-iterating,
|
||||
fixme,
|
||||
getslice-method,
|
||||
global-statement,
|
||||
hex-method,
|
||||
idiv-method,
|
||||
implicit-str-concat-in-sequence,
|
||||
import-error,
|
||||
import-self,
|
||||
import-star-module-level,
|
||||
inconsistent-return-statements,
|
||||
input-builtin,
|
||||
intern-builtin,
|
||||
invalid-str-codec,
|
||||
locally-disabled,
|
||||
logging-fstring-interpolation, # added by vLLM
|
||||
logging-not-lazy, # added by vLLM
|
||||
long-builtin,
|
||||
long-suffix,
|
||||
map-builtin-not-iterating,
|
||||
misplaced-comparison-constant,
|
||||
missing-class-docstring, # TODO (vLLM): enable
|
||||
missing-function-docstring,
|
||||
missing-module-docstring, # TODO (vLLM): enable
|
||||
metaclass-assignment,
|
||||
next-method-called,
|
||||
next-method-defined,
|
||||
no-absolute-import,
|
||||
no-else-break,
|
||||
no-else-continue,
|
||||
no-else-raise,
|
||||
no-else-return,
|
||||
no-init, # added
|
||||
no-member,
|
||||
no-name-in-module,
|
||||
no-self-use,
|
||||
nonzero-method,
|
||||
oct-method,
|
||||
old-division,
|
||||
old-ne-operator,
|
||||
old-octal-literal,
|
||||
old-raise-syntax,
|
||||
parameter-unpacking,
|
||||
print-statement,
|
||||
raising-string,
|
||||
range-builtin-not-iterating,
|
||||
raw_input-builtin,
|
||||
rdiv-method,
|
||||
reduce-builtin,
|
||||
relative-import,
|
||||
reload-builtin,
|
||||
round-builtin,
|
||||
setslice-method,
|
||||
signature-differs,
|
||||
standarderror-builtin,
|
||||
suppressed-message,
|
||||
sys-max-int,
|
||||
too-few-public-methods,
|
||||
too-many-ancestors,
|
||||
too-many-arguments,
|
||||
too-many-boolean-expressions,
|
||||
too-many-branches,
|
||||
too-many-instance-attributes,
|
||||
too-many-locals,
|
||||
too-many-nested-blocks,
|
||||
too-many-public-methods,
|
||||
too-many-return-statements,
|
||||
too-many-statements,
|
||||
trailing-newlines,
|
||||
unichr-builtin,
|
||||
unicode-builtin,
|
||||
unnecessary-pass,
|
||||
unpacking-in-except,
|
||||
unspecified-encoding,
|
||||
useless-else-on-loop,
|
||||
useless-object-inheritance,
|
||||
useless-suppression,
|
||||
using-cmp-argument,
|
||||
wrong-import-order,
|
||||
xrange-builtin,
|
||||
zip-builtin-not-iterating,
|
||||
|
||||
|
||||
[REPORTS]
|
||||
|
||||
# Set the output format. Available formats are text, parseable, colorized, msvs
|
||||
# (visual studio) and html. You can also give a reporter class, eg
|
||||
# mypackage.mymodule.MyReporterClass.
|
||||
output-format=text
|
||||
|
||||
# Tells whether to display a full report or only the messages
|
||||
reports=no
|
||||
|
||||
# Python expression which should return a note less than 10 (10 is the highest
|
||||
# note). You have access to the variables errors warning, statement which
|
||||
# respectively contain the number of errors / warnings messages and the total
|
||||
# number of statements analyzed. This is used by the global evaluation report
|
||||
# (RP0004).
|
||||
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
|
||||
|
||||
# Template used to display messages. This is a python new-style format string
|
||||
# used to format the message information. See doc for all details
|
||||
#msg-template=
|
||||
|
||||
|
||||
[BASIC]
|
||||
|
||||
# Good variable names which should always be accepted, separated by a comma
|
||||
good-names=main,_
|
||||
|
||||
# Bad variable names which should always be refused, separated by a comma
|
||||
bad-names=
|
||||
|
||||
# Colon-delimited sets of names that determine each other's naming style when
|
||||
# the name regexes allow several styles.
|
||||
name-group=
|
||||
|
||||
# Include a hint for the correct naming format with invalid-name
|
||||
include-naming-hint=no
|
||||
|
||||
# List of decorators that produce properties, such as abc.abstractproperty. Add
|
||||
# to this list to register other decorators that produce valid properties.
|
||||
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
|
||||
|
||||
# Regular expression matching correct function names
|
||||
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression matching correct variable names
|
||||
variable-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct constant names
|
||||
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct attribute names
|
||||
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct argument names
|
||||
argument-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class attribute names
|
||||
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
|
||||
|
||||
# Regular expression matching correct inline iteration names
|
||||
inlinevar-rgx=^[a-z][a-z0-9_]*$
|
||||
|
||||
# Regular expression matching correct class names
|
||||
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
|
||||
|
||||
# Regular expression matching correct module names
|
||||
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
|
||||
|
||||
# Regular expression matching correct method names
|
||||
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
|
||||
|
||||
# Regular expression which should only match function or class names that do
|
||||
# not require a docstring.
|
||||
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
|
||||
|
||||
# Minimum line length for functions/classes that require docstrings, shorter
|
||||
# ones are exempt.
|
||||
docstring-min-length=10
|
||||
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
# List of decorators that produce context managers, such as
|
||||
# contextlib.contextmanager. Add to this list to register other decorators that
|
||||
# produce valid context managers.
|
||||
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
|
||||
|
||||
# Tells whether missing members accessed in mixin class should be ignored. A
|
||||
# mixin class is detected if its name ends with "mixin" (case insensitive).
|
||||
ignore-mixin-members=yes
|
||||
|
||||
# List of module names for which member attributes should not be checked
|
||||
# (useful for modules/projects where namespaces are manipulated during runtime
|
||||
# and thus existing member attributes cannot be deduced by static analysis. It
|
||||
# supports qualified module names, as well as Unix pattern matching.
|
||||
ignored-modules=
|
||||
|
||||
# List of class names for which member attributes should not be checked (useful
|
||||
# for classes with dynamically set attributes). This supports the use of
|
||||
# qualified names.
|
||||
ignored-classes=optparse.Values,thread._local,_thread._local
|
||||
|
||||
# List of members which are set dynamically and missed by pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed. Python regular
|
||||
# expressions are accepted.
|
||||
generated-members=
|
||||
|
||||
|
||||
[FORMAT]
|
||||
|
||||
# Maximum number of characters on a single line.
|
||||
max-line-length=80
|
||||
|
||||
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
|
||||
# lines made too long by directives to pytype.
|
||||
|
||||
# Regexp for a line that is allowed to be longer than the limit.
|
||||
ignore-long-lines=(?x)(
|
||||
^\s*(\#\ )?<?https?://\S+>?$|
|
||||
^\s*(from\s+\S+\s+)?import\s+.+$)
|
||||
|
||||
# Allow the body of an if to be on the same line as the test if there is no
|
||||
# else.
|
||||
single-line-if-stmt=yes
|
||||
|
||||
# Maximum number of lines in a module
|
||||
max-module-lines=99999
|
||||
|
||||
# String used as indentation unit. The internal Google style guide mandates 2
|
||||
# spaces. Google's externaly-published style guide says 4, consistent with
|
||||
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
|
||||
# projects (like TensorFlow).
|
||||
indent-string=' '
|
||||
|
||||
# Number of spaces of indent required inside a hanging or continued line.
|
||||
indent-after-paren=4
|
||||
|
||||
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
|
||||
expected-line-ending-format=
|
||||
|
||||
|
||||
[MISCELLANEOUS]
|
||||
|
||||
# List of note tags to take in consideration, separated by a comma.
|
||||
notes=TODO
|
||||
|
||||
|
||||
[STRING]
|
||||
|
||||
# This flag controls whether inconsistent-quotes generates a warning when the
|
||||
# character used as a quote delimiter is used inconsistently within a module.
|
||||
check-quote-consistency=yes
|
||||
|
||||
|
||||
[VARIABLES]
|
||||
|
||||
# Tells whether we should check for unused import in __init__ files.
|
||||
init-import=no
|
||||
|
||||
# A regular expression matching the name of dummy variables (i.e. expectedly
|
||||
# not used).
|
||||
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
|
||||
|
||||
# List of additional names supposed to be defined in builtins. Remember that
|
||||
# you should avoid to define new builtins when possible.
|
||||
additional-builtins=
|
||||
|
||||
# List of strings which can identify a callback function by name. A callback
|
||||
# name must start or end with one of those strings.
|
||||
callbacks=cb_,_cb
|
||||
|
||||
# List of qualified module names which can have objects that can redefine
|
||||
# builtins.
|
||||
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
|
||||
|
||||
|
||||
[LOGGING]
|
||||
|
||||
# Logging modules to check that the string format arguments are in logging
|
||||
# function parameter format
|
||||
logging-modules=logging,absl.logging,tensorflow.io.logging
|
||||
|
||||
|
||||
[SIMILARITIES]
|
||||
|
||||
# Minimum lines number of a similarity.
|
||||
min-similarity-lines=4
|
||||
|
||||
# Ignore comments when computing similarities.
|
||||
ignore-comments=yes
|
||||
|
||||
# Ignore docstrings when computing similarities.
|
||||
ignore-docstrings=yes
|
||||
|
||||
# Ignore imports when computing similarities.
|
||||
ignore-imports=no
|
||||
|
||||
|
||||
[SPELLING]
|
||||
|
||||
# Spelling dictionary name. Available dictionaries: none. To make it working
|
||||
# install python-enchant package.
|
||||
spelling-dict=
|
||||
|
||||
# List of comma separated words that should not be checked.
|
||||
spelling-ignore-words=
|
||||
|
||||
# A path to a file that contains private dictionary; one word per line.
|
||||
spelling-private-dict-file=
|
||||
|
||||
# Tells whether to store unknown words to indicated private dictionary in
|
||||
# --spelling-private-dict-file option instead of raising a message.
|
||||
spelling-store-unknown-words=no
|
||||
|
||||
|
||||
[IMPORTS]
|
||||
|
||||
# Deprecated modules which should not be used, separated by a comma
|
||||
deprecated-modules=regsub,
|
||||
TERMIOS,
|
||||
Bastion,
|
||||
rexec,
|
||||
sets
|
||||
|
||||
# Create a graph of every (i.e. internal and external) dependencies in the
|
||||
# given file (report RP0402 must not be disabled)
|
||||
import-graph=
|
||||
|
||||
# Create a graph of external dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
ext-import-graph=
|
||||
|
||||
# Create a graph of internal dependencies in the given file (report RP0402 must
|
||||
# not be disabled)
|
||||
int-import-graph=
|
||||
|
||||
# Force import order to recognize a module as part of the standard
|
||||
# compatibility libraries.
|
||||
known-standard-library=
|
||||
|
||||
# Force import order to recognize a module as part of a third party library.
|
||||
known-third-party=enchant, absl
|
||||
|
||||
# Analyse import fallback blocks. This can be used to support both Python 2 and
|
||||
# 3 compatible code, which means that the block might have code that exists
|
||||
# only in one or another interpreter, leading to false positives when analysed.
|
||||
analyse-fallback-blocks=no
|
||||
|
||||
|
||||
[CLASSES]
|
||||
|
||||
# List of method names used to declare (i.e. assign) instance attributes.
|
||||
defining-attr-methods=__init__,
|
||||
__new__,
|
||||
setUp
|
||||
|
||||
# List of member names, which should be excluded from the protected access
|
||||
# warning.
|
||||
exclude-protected=_asdict,
|
||||
_fields,
|
||||
_replace,
|
||||
_source,
|
||||
_make
|
||||
|
||||
# List of valid names for the first argument in a class method.
|
||||
valid-classmethod-first-arg=cls,
|
||||
class_
|
||||
|
||||
# List of valid names for the first argument in a metaclass class method.
|
||||
valid-metaclass-classmethod-first-arg=mcs
|
||||
|
||||
|
||||
[EXCEPTIONS]
|
||||
|
||||
# Exceptions that will emit a warning when being caught. Defaults to
|
||||
# "Exception"
|
||||
overgeneral-exceptions=StandardError,
|
||||
Exception,
|
||||
BaseException
|
105
Dockerfile
Normal file
@ -0,0 +1,105 @@
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip git
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
# this won't be needed for future versions of this docker image
|
||||
# or future versions of triton.
|
||||
RUN ldconfig /usr/local/cuda-12.1/compat/
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
COPY requirements.txt requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements.txt
|
||||
|
||||
# install development dependencies
|
||||
COPY requirements-dev.txt requirements-dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements-dev.txt
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
|
||||
|
||||
#################### EXTENSION BUILD IMAGE ####################
|
||||
FROM dev AS build
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements-build.txt
|
||||
|
||||
# copy input files
|
||||
COPY csrc csrc
|
||||
COPY setup.py setup.py
|
||||
COPY requirements.txt requirements.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm/__init__.py vllm/__init__.py
|
||||
|
||||
# cuda arch list used by torch
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# max jobs used by Ninja to build extensions
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
# make sure punica kernels are built (for LoRA)
|
||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
RUN python3 setup.py build_ext --inplace
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
# image to run unit testing suite
|
||||
FROM dev AS test
|
||||
|
||||
# copy pytorch extensions separately to avoid having to rebuild
|
||||
# when python code changes
|
||||
WORKDIR /vllm-workspace
|
||||
# ADD is used to preserve directory structure
|
||||
ADD . /vllm-workspace/
|
||||
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
|
||||
# ignore build dependencies installation because we are using pre-complied extensions
|
||||
RUN rm pyproject.toml
|
||||
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
|
||||
#################### RUNTIME BASE IMAGE ####################
|
||||
# We used base cuda image because pytorch installs its own cuda libraries.
|
||||
# However cupy depends on cuda libraries so we had to switch to the runtime image
|
||||
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
|
||||
|
||||
# libnccl required for ray
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip
|
||||
|
||||
WORKDIR /workspace
|
||||
COPY requirements.txt requirements.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -r requirements.txt
|
||||
#################### RUNTIME BASE IMAGE ####################
|
||||
|
||||
|
||||
#################### OPENAI API SERVER ####################
|
||||
# openai api server alternative
|
||||
FROM vllm-base AS vllm-openai
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate
|
||||
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY vllm vllm
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
#################### OPENAI API SERVER ####################
|
95
Dockerfile.rocm
Normal file
@ -0,0 +1,95 @@
|
||||
# default base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
RUN echo "Base image is $BASE_IMAGE"
|
||||
|
||||
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
||||
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
||||
|
||||
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
||||
|
||||
ARG FA_BRANCH="3d2b6f5"
|
||||
RUN echo "FA_BRANCH is $FA_BRANCH"
|
||||
|
||||
# whether to build flash-attention
|
||||
# if 0, will not build flash attention
|
||||
# this is useful for gfx target where flash-attention is not supported
|
||||
# In that case, we need to use the python reference attention implementation in vllm
|
||||
ARG BUILD_FA="1"
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
nvidia-cuda-toolkit \
|
||||
tmux \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Mount Point ###
|
||||
# When launching the container, mount the code directory to /app
|
||||
ARG APP_MOUNT=/app
|
||||
VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
# Install ROCm flash-attention
|
||||
RUN if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout ${FA_BRANCH} \
|
||||
&& git submodule update --init \
|
||||
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
|
||||
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
|
||||
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..; \
|
||||
fi
|
||||
|
||||
COPY ./ /app/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install xformers==0.0.23 --no-deps
|
||||
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually removed it so that later steps of numpy upgrade can continue
|
||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
||||
|
||||
RUN cd /app \
|
||||
&& cd vllm \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& if [ "$BUILD_FA" = "1" ]; then \
|
||||
bash patch_xformers.rocm.sh; fi \
|
||||
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir ray[all]
|
||||
|
||||
CMD ["/bin/bash"]
|
80
README.md
@ -10,18 +10,26 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
|
||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
||||
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||
- [2024/01] Added ROCm 6.0 support to vLLM.
|
||||
- [2023/12] Added ROCm 5.7 support to vLLM.
|
||||
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||
|
||||
---
|
||||
|
||||
## About
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
vLLM is fast with:
|
||||
@ -29,25 +37,45 @@ vLLM is fast with:
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with **PagedAttention**
|
||||
- Continuous batching of incoming requests
|
||||
- Fast model execution with CUDA/HIP graph
|
||||
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
|
||||
- Optimized CUDA kernels
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular HuggingFace models
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs and AMD GPUs
|
||||
- (Experimental) Prefix caching support
|
||||
- (Experimental) Multi-lora support
|
||||
|
||||
vLLM seamlessly supports many Huggingface models, including the following architectures:
|
||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||
|
||||
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
||||
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
|
||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
||||
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
|
||||
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
|
||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
||||
- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
||||
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
|
||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
||||
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
|
||||
|
||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||
|
||||
@ -62,37 +90,19 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
|
||||
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
|
||||
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
|
||||
|
||||
## Performance
|
||||
|
||||
vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
|
||||
For details, check out our [blog post](https://vllm.ai).
|
||||
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_dark.png">
|
||||
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_light.png" width="45%">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_dark.png">
|
||||
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_light.png" width="45%">
|
||||
</picture>
|
||||
<br>
|
||||
<em> Serving throughput when each request asks for 1 output completion. </em>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_dark.png">
|
||||
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_light.png" width="45%">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_dark.png">
|
||||
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_light.png" width="45%">
|
||||
</picture> <br>
|
||||
<em> Serving throughput when each request asks for 3 output completions. </em>
|
||||
</p>
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome and value any contributions and collaborations.
|
||||
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
||||
|
||||
## Citation
|
||||
|
||||
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||
```bibtex
|
||||
@inproceedings{kwon2023efficient,
|
||||
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
|
||||
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
284
benchmarks/backend_request_func.py
Normal file
@ -0,0 +1,284 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncInput:
|
||||
prompt: str
|
||||
api_url: str
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
best_of: int = 1
|
||||
use_beam_search: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0
|
||||
ttft: float = 0
|
||||
prompt_len: int = 0
|
||||
|
||||
|
||||
async def async_request_tgi(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"do_sample": True,
|
||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||
}
|
||||
payload = {
|
||||
"inputs": request_func_input.prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for data in response.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
body = data.decode("utf-8").lstrip("data:")
|
||||
output.generated_text = json.loads(body)["generated_text"]
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_vllm(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"prompt": request_func_input.prompt,
|
||||
"n": 1,
|
||||
"best_of": request_func_input.best_of,
|
||||
"use_beam_search": request_func_input.use_beam_search,
|
||||
"temperature": 0.0 if request_func_input.use_beam_search else 1.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"ignore_eos": True,
|
||||
"stream": True,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for data in response.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
# When streaming, '\0' is appended to the end of the response.
|
||||
body = data.decode("utf-8").strip("\0")
|
||||
output.generated_text = json.loads(
|
||||
body)["text"][0][len(request_func_input.prompt):]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_trt_llm(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
"text_input": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
ttft = 0
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
async for data in resp.content.iter_any():
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
output.latency = time.perf_counter() - st
|
||||
|
||||
body = data.decode("utf-8").lstrip("data:")
|
||||
output.generated_text = json.loads(body)["text_output"]
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_deepspeed_mii(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
assert not request_func_input.use_beam_search
|
||||
|
||||
payload = {
|
||||
"prompts": request_func_input.prompt,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
"ignore_eos": True,
|
||||
"do_sample": True,
|
||||
"temperature":
|
||||
0.01, # deepspeed-mii does not accept 0.0 temperature.
|
||||
"top_p": 1.0,
|
||||
}
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
|
||||
# https://github.com/microsoft/DeepSpeed-MII/pull/311
|
||||
output.ttft = 0
|
||||
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=request_func_input.api_url,
|
||||
json=payload) as resp:
|
||||
if resp.status == 200:
|
||||
parsed_resp = await resp.json()
|
||||
output.latency = time.perf_counter() - st
|
||||
output.generated_text = parsed_resp[0]["generated_text"]
|
||||
output.success = True
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("v1/completions")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert not request_func_input.use_beam_search
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
||||
}
|
||||
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
st = time.perf_counter()
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
if ttft == 0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
chunk = chunk.decode("utf-8").lstrip("data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
body = json.loads(chunk)
|
||||
generated_text += body["choices"][0]["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.success = False
|
||||
except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
|
||||
output.success = False
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
"tgi": async_request_tgi,
|
||||
"vllm": async_request_vllm,
|
||||
"deepspeed-mii": async_request_deepspeed_mii,
|
||||
"openai": async_request_openai_completions,
|
||||
"tensorrt-llm": async_request_trt_llm,
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -12,15 +14,18 @@ from vllm import LLM, SamplingParams
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
# Process all the requests in a single batch if possible.
|
||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
||||
# the engine will automatically process the request in multiple batches.
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
dtype=args.dtype,
|
||||
enforce_eager=args.enforce_eager,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
device=args.device,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@ -32,47 +37,113 @@ def main(args: argparse.Namespace):
|
||||
max_tokens=args.output_len,
|
||||
)
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
|
||||
|
||||
def run_to_completion(profile: bool = False):
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.time()
|
||||
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
|
||||
end_time = time.time()
|
||||
latency = end_time - start_time
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
return latency
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print(p.key_averages())
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
run_to_completion(profile=False)
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
profile_dir = args.profile_result_dir
|
||||
if not profile_dir:
|
||||
profile_dir = Path(
|
||||
"."
|
||||
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=profile_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile=False))
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
parser.add_argument('--output-len', type=int, default=128)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--n', type=int, default=1,
|
||||
parser.add_argument('--n',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of generated sequences per prompt.')
|
||||
parser.add_argument('--use-beam-search', action='store_true')
|
||||
parser.add_argument('--num-iters', type=int, default=3,
|
||||
parser.add_argument('--num-iters',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of iterations to run.')
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='enforce eager mode and disable CUDA graph')
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=['auto', 'fp8_e5m2'],
|
||||
default='auto',
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument(
|
||||
'--profile',
|
||||
action='store_true',
|
||||
help='profile the generation process of a single batch')
|
||||
parser.add_argument(
|
||||
'--profile-result-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help=('path to save the pytorch profiler output. Can be visualized '
|
||||
'with ui.perfetto.dev or Tensorboard.'))
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -20,15 +20,36 @@ import asyncio
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# (prompt len, output len, latency)
|
||||
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
|
||||
from backend_request_func import (
|
||||
ASYNC_REQUEST_FUNCS,
|
||||
RequestFuncInput,
|
||||
RequestFuncOutput,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkMetrics:
|
||||
completed: int
|
||||
total_input: int
|
||||
total_output: int
|
||||
request_throughput: float
|
||||
input_throughput: float
|
||||
output_throughput: float
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
p99_ttft_ms: float
|
||||
mean_tpot_ms: float
|
||||
median_tpot_ms: float
|
||||
p99_tpot_ms: float
|
||||
|
||||
|
||||
def sample_requests(
|
||||
@ -40,15 +61,15 @@ def sample_requests(
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [
|
||||
data for data in dataset
|
||||
if len(data["conversations"]) >= 2
|
||||
]
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# some of these will be filtered out, so sample more than we need
|
||||
sampled_indices = random.sample(range(len(dataset)),
|
||||
int(num_requests * 1.2))
|
||||
dataset = [dataset[i] for i in sampled_indices]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
@ -96,79 +117,125 @@ async def get_request(
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def send_request(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
prompt: str,
|
||||
prompt_len: int,
|
||||
output_len: int,
|
||||
best_of: int,
|
||||
use_beam_search: bool,
|
||||
) -> None:
|
||||
request_start_time = time.time()
|
||||
def calculate_metrics(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
outputs: List[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> BenchmarkMetrics:
|
||||
total_output = 0
|
||||
total_input = 0
|
||||
completed = 0
|
||||
per_token_latencies = []
|
||||
ttfts = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
output_len = len(tokenizer.encode(outputs[i].generated_text))
|
||||
total_output += output_len
|
||||
total_input += input_requests[i][1]
|
||||
per_token_latencies.append(outputs[i].latency / output_len)
|
||||
ttfts.append(outputs[i].ttft)
|
||||
completed += 1
|
||||
|
||||
headers = {"User-Agent": "Benchmark Client"}
|
||||
if backend == "vllm":
|
||||
pload = {
|
||||
"prompt": prompt,
|
||||
"n": 1,
|
||||
"best_of": best_of,
|
||||
"use_beam_search": use_beam_search,
|
||||
"temperature": 0.0 if use_beam_search else 1.0,
|
||||
"top_p": 1.0,
|
||||
"max_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
"stream": False,
|
||||
}
|
||||
elif backend == "tgi":
|
||||
assert not use_beam_search
|
||||
params = {
|
||||
"best_of": best_of,
|
||||
"max_new_tokens": output_len,
|
||||
"do_sample": True,
|
||||
}
|
||||
pload = {
|
||||
"inputs": prompt,
|
||||
"parameters": params,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
metrics = BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
total_output=total_output,
|
||||
request_throughput=completed / dur_s,
|
||||
input_throughput=total_input / dur_s,
|
||||
output_throughput=total_output / dur_s,
|
||||
mean_ttft_ms=np.mean(ttfts) * 1000,
|
||||
median_ttft_ms=np.median(ttfts) * 1000,
|
||||
p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
|
||||
mean_tpot_ms=np.mean(per_token_latencies) * 1000,
|
||||
median_tpot_ms=np.median(per_token_latencies) * 1000,
|
||||
p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
|
||||
)
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=3 * 3600)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
while True:
|
||||
async with session.post(api_url, headers=headers, json=pload) as response:
|
||||
chunks = []
|
||||
async for chunk, _ in response.content.iter_chunks():
|
||||
chunks.append(chunk)
|
||||
output = b"".join(chunks).decode("utf-8")
|
||||
output = json.loads(output)
|
||||
|
||||
# Re-send the request if it failed.
|
||||
if "error" not in output:
|
||||
break
|
||||
|
||||
request_end_time = time.time()
|
||||
request_latency = request_end_time - request_start_time
|
||||
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
|
||||
return metrics
|
||||
|
||||
|
||||
async def benchmark(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
model_id: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
best_of: int,
|
||||
use_beam_search: bool,
|
||||
request_rate: float,
|
||||
) -> None:
|
||||
tasks: List[asyncio.Task] = []
|
||||
disable_tqdm: bool,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS.get(backend)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks = []
|
||||
async for request in get_request(input_requests, request_rate):
|
||||
prompt, prompt_len, output_len = request
|
||||
task = asyncio.create_task(send_request(backend, api_url, prompt,
|
||||
prompt_len, output_len,
|
||||
best_of, use_beam_search))
|
||||
tasks.append(task)
|
||||
await asyncio.gather(*tasks)
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
best_of=best_of,
|
||||
use_beam_search=use_beam_search,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
request_func(request_func_input=request_func_input,
|
||||
pbar=pbar)))
|
||||
outputs = await asyncio.gather(*tasks)
|
||||
|
||||
if not disable_tqdm:
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
metrics = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
print(f"Successful requests: {metrics.completed}")
|
||||
print(f"Benchmark duration: {benchmark_duration:2f} s")
|
||||
print(f"Total input tokens: {metrics.total_input}")
|
||||
print(f"Total generated tokens: {metrics.total_output}")
|
||||
print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
|
||||
print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
|
||||
print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
|
||||
print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
|
||||
print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
|
||||
print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
|
||||
print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
|
||||
print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
|
||||
print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_inthroughput": metrics.request_throughput,
|
||||
"input_throughput": metrics.input_throughput,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
||||
"median_ttft_ms": metrics.median_ttft_ms,
|
||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
||||
"median_tpot_ms": metrics.median_tpot_ms,
|
||||
"p99_tpot_ms": metrics.p99_tpot_ms
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -176,56 +243,145 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
api_url = f"http://{args.host}:{args.port}/generate"
|
||||
tokenizer = get_tokenizer(args.tokenizer)
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
else:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
benchmark_start_time = time.time()
|
||||
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
|
||||
args.use_beam_search, args.request_rate))
|
||||
benchmark_end_time = time.time()
|
||||
benchmark_time = benchmark_end_time - benchmark_start_time
|
||||
print(f"Total time: {benchmark_time:.2f} s")
|
||||
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
best_of=args.best_of,
|
||||
use_beam_search=args.use_beam_search,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
))
|
||||
|
||||
# Compute the latency statistics.
|
||||
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
||||
print(f"Average latency: {avg_latency:.2f} s")
|
||||
avg_per_token_latency = np.mean([
|
||||
latency / (prompt_len + output_len)
|
||||
for prompt_len, output_len, latency in REQUEST_LATENCY
|
||||
])
|
||||
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
||||
avg_per_output_token_latency = np.mean([
|
||||
latency / output_len
|
||||
for _, output_len, latency in REQUEST_LATENCY
|
||||
])
|
||||
print("Average latency per output token: "
|
||||
f"{avg_per_output_token_latency:.2f} s")
|
||||
# Save config and results to json
|
||||
if args.save_result:
|
||||
result_json = {}
|
||||
|
||||
# Setup
|
||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
result_json["date"] = current_dt
|
||||
result_json["backend"] = backend
|
||||
result_json["version"] = args.version
|
||||
result_json["model_id"] = model_id
|
||||
result_json["tokenizer_id"] = tokenizer_id
|
||||
result_json["best_of"] = args.best_of
|
||||
result_json["use_beam_search"] = args.use_beam_search
|
||||
result_json["num_prompts"] = args.num_prompts
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
# Save to file
|
||||
base_model_id = model_id.split("/")[-1]
|
||||
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
|
||||
with open(file_name, "w") as outfile:
|
||||
json.dump(result_json, outfile)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the online serving throughput.")
|
||||
parser.add_argument("--backend", type=str, default="vllm",
|
||||
choices=["vllm", "tgi"])
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default="vllm",
|
||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="N/A",
|
||||
help="Version of the serving backend/engine.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Server or API base url if not using http host and port.",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
parser.add_argument(
|
||||
"--endpoint",
|
||||
type=str,
|
||||
default="/generate",
|
||||
help="API endpoint.",
|
||||
)
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--tokenizer", type=str, required=True,
|
||||
help="Name or path of the tokenizer.")
|
||||
parser.add_argument("--best-of", type=int, default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.")
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Name of the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer",
|
||||
type=str,
|
||||
help=
|
||||
"Name or path of the tokenizer, if not using the default model tokenizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best-of",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Generates `best_of` sequences per prompt and "
|
||||
"returns the best one.",
|
||||
)
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--request-rate", type=float, default=float("inf"),
|
||||
help="Number of requests per second. If this is inf, "
|
||||
"then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.")
|
||||
parser.add_argument(
|
||||
"--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
default=float("inf"),
|
||||
help="Number of requests per second. If this is inf, "
|
||||
"then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize "
|
||||
"the request arrival times.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Trust remote code from huggingface",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-tqdm",
|
||||
action="store_true",
|
||||
help="Specify to disbale tqdm progress bar.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-result",
|
||||
action="store_true",
|
||||
help="Specify to save benchmark results to a json file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -3,34 +3,31 @@ import argparse
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int],
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
raise ValueError("output_len too small")
|
||||
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [
|
||||
data for data in dataset
|
||||
if len(data["conversations"]) >= 2
|
||||
]
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
@ -40,6 +37,8 @@ def sample_requests(
|
||||
tokenized_dataset = []
|
||||
for i in range(len(dataset)):
|
||||
output_len = len(completion_token_ids[i])
|
||||
if fixed_output_len is not None:
|
||||
output_len = fixed_output_len
|
||||
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
|
||||
|
||||
# Filter out too long sequences.
|
||||
@ -63,16 +62,31 @@ def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
trust_remote_code: bool,
|
||||
dtype: str,
|
||||
max_model_len: Optional[int],
|
||||
enforce_eager: bool,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
@ -92,10 +106,10 @@ def run_vllm(
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
# FIXME(woosuk): Do use internal method.
|
||||
start = time.perf_counter()
|
||||
# FIXME(woosuk): Do not use internal method.
|
||||
llm._run_engine(use_tqdm=True)
|
||||
end = time.time()
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
@ -106,16 +120,18 @@ def run_hf(
|
||||
n: int,
|
||||
use_beam_search: bool,
|
||||
max_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
) -> float:
|
||||
assert not use_beam_search
|
||||
llm = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
llm = llm.cuda()
|
||||
|
||||
pbar = tqdm(total=len(requests))
|
||||
start = time.time()
|
||||
start = time.perf_counter()
|
||||
batch: List[str] = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
@ -128,13 +144,14 @@ def run_hf(
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
if (max(max_prompt_len, next_prompt_len) + max(
|
||||
max_output_len, next_output_len)) <= 2048:
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
input_ids = tokenizer(batch, return_tensors="pt",
|
||||
padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=not use_beam_search,
|
||||
@ -152,7 +169,23 @@ def run_hf(
|
||||
batch = []
|
||||
max_prompt_len = 0
|
||||
max_output_len = 0
|
||||
end = time.time()
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
def run_mii(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tensor_parallel_size: int,
|
||||
output_len: int,
|
||||
) -> float:
|
||||
from mii import pipeline
|
||||
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
|
||||
prompts = [prompt for prompt, _, _ in requests]
|
||||
|
||||
start = time.perf_counter()
|
||||
llm(prompts, max_new_tokens=output_len)
|
||||
end = time.perf_counter()
|
||||
return end - start
|
||||
|
||||
|
||||
@ -161,45 +194,122 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = get_tokenizer(args.tokenizer)
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
if args.dataset is None:
|
||||
# Synthesize a prompt with the given input length.
|
||||
prompt = "hi" * (args.input_len - 1)
|
||||
requests = [(prompt, args.input_len, args.output_len)
|
||||
for _ in range(args.num_prompts)]
|
||||
else:
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search)
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len, args.enforce_eager,
|
||||
args.kv_cache_dtype, args.device)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.use_beam_search, args.hf_max_batch_size)
|
||||
args.use_beam_search, args.hf_max_batch_size,
|
||||
args.trust_remote_code)
|
||||
elif args.backend == "mii":
|
||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
||||
args.output_len)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(
|
||||
prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests
|
||||
)
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests)
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf", "mii"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--input-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Input prompt length for each request")
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n", type=int, default=1,
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
||||
parser.add_argument("--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
parser.add_argument(
|
||||
'--max-model-len',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Maximum length of a sequence (including prompt and output). '
|
||||
'If None, will be derived from the model.')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='auto',
|
||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
||||
help='data type for model weights and activations. '
|
||||
'The "auto" option will use FP16 precision '
|
||||
'for FP32 and FP16 models, and BF16 precision '
|
||||
'for BF16 models.')
|
||||
parser.add_argument("--enforce-eager",
|
||||
action="store_true",
|
||||
help="enforce eager execution")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda",
|
||||
choices=["cuda"],
|
||||
help='device type for vLLM execution, supporting CUDA only currently.')
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
if args.dataset is None:
|
||||
assert args.input_len is not None
|
||||
assert args.output_len is not None
|
||||
else:
|
||||
assert args.input_len is None
|
||||
|
||||
if args.backend == "vllm":
|
||||
if args.hf_max_batch_size is not None:
|
||||
@ -207,7 +317,20 @@ if __name__ == "__main__":
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
elif args.backend == "mii":
|
||||
if args.dtype != "auto":
|
||||
raise ValueError("dtype must be auto for MII backend.")
|
||||
if args.n != 1:
|
||||
raise ValueError("n must be 1 for MII backend.")
|
||||
if args.use_beam_search:
|
||||
raise ValueError("Beam search is not supported for MII backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.hf_max_batch_size is not None:
|
||||
raise ValueError("HF max batch size is only for HF backend.")
|
||||
if args.tokenizer != args.model:
|
||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
||||
"backend.")
|
||||
main(args)
|
||||
|
205
benchmarks/kernels/benchmark_paged_attention.py
Normal file
@ -0,0 +1,205 @@
|
||||
from typing import Optional
|
||||
import argparse
|
||||
import random
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
from vllm._C import ops
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
PARTITION_SIZE = 512
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main(
|
||||
version: str,
|
||||
num_seqs: int,
|
||||
context_len: int,
|
||||
num_query_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
use_alibi: bool,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
do_profile: bool,
|
||||
device: str = "cuda",
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device=device)
|
||||
|
||||
context_lens = [context_len for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
|
||||
|
||||
# Create the KV cache.
|
||||
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
output = torch.empty_like(query)
|
||||
if version == "v2":
|
||||
num_partitions = ((max_context_len + PARTITION_SIZE - 1) //
|
||||
PARTITION_SIZE)
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_query_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
torch.cuda.synchronize()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
block_size,
|
||||
max_context_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
return (end_time - start_time) / num_iters
|
||||
|
||||
# Warmup.
|
||||
print("Warming up...")
|
||||
run_benchmark = run_cuda_benchmark
|
||||
run_benchmark(num_iters=3, profile=False)
|
||||
|
||||
# Benchmark.
|
||||
if do_profile:
|
||||
latency = run_benchmark(num_iters=1, profile=True)
|
||||
else:
|
||||
latency = run_benchmark(num_iters=100, profile=False)
|
||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark the paged attention kernel.")
|
||||
parser.add_argument("--version",
|
||||
type=str,
|
||||
choices=["v1", "v2"],
|
||||
default="v2")
|
||||
parser.add_argument("--batch-size", type=int, default=8)
|
||||
parser.add_argument("--context-len", type=int, default=4096)
|
||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 128, 256],
|
||||
default=128)
|
||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||
parser.add_argument("--use-alibi", action="store_true")
|
||||
parser.add_argument("--dtype",
|
||||
type=str,
|
||||
choices=["half", "bfloat16", "float"],
|
||||
default="half")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--profile", action="store_true")
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
choices=["auto", "fp8_e5m2"],
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.num_query_heads % args.num_kv_heads != 0:
|
||||
raise ValueError("num_query_heads must be divisible by num_kv_heads")
|
||||
main(
|
||||
version=args.version,
|
||||
num_seqs=args.batch_size,
|
||||
context_len=args.context_len,
|
||||
num_query_heads=args.num_query_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
head_size=args.head_size,
|
||||
block_size=args.block_size,
|
||||
use_alibi=args.use_alibi,
|
||||
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||
seed=args.seed,
|
||||
do_profile=args.profile,
|
||||
kv_cache_dtype=args.kv_cache_dtype,
|
||||
)
|
@ -6,7 +6,7 @@ TOKENS=$2
|
||||
|
||||
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
||||
-v $PWD/data:/data \
|
||||
ghcr.io/huggingface/text-generation-inference:0.8 \
|
||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
||||
--model-id $MODEL \
|
||||
--sharded false \
|
||||
--max-input-length 1024 \
|
||||
|
@ -1,12 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@ -11,13 +15,13 @@ __device__ __forceinline__ T silu(const T& x) {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void silu_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [num_tokens, d]
|
||||
const scalar_t* __restrict__ input, // [num_tokens, 2, d]
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
const int d) {
|
||||
const int token_idx = blockIdx.x;
|
||||
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = silu(x) * y;
|
||||
}
|
||||
}
|
||||
@ -25,18 +29,17 @@ __global__ void silu_and_mul_kernel(
|
||||
} // namespace vllm
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out, // [num_tokens, d]
|
||||
torch::Tensor& input) // [num_tokens, 2 * d]
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
int num_tokens = input.size(0);
|
||||
int d = input.size(1) / 2;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int d = input.size(-1) / 2;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(d, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"silu_and_mul_kernel",
|
||||
[&] {
|
||||
@ -46,3 +49,70 @@ void silu_and_mul(
|
||||
d);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Element-wise activation kernel template.
|
||||
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
__global__ void activation_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., d]
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// Launch element-wise activation kernel.
|
||||
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
|
||||
int d = input.size(-1); \
|
||||
int64_t num_tokens = input.numel() / d; \
|
||||
dim3 grid(num_tokens); \
|
||||
dim3 block(std::min(d, 1024)); \
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), \
|
||||
"activation_kernel", \
|
||||
[&] { \
|
||||
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), \
|
||||
d); \
|
||||
});
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
|
||||
const float x3 = (float) (x * x * x);
|
||||
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
|
||||
const float f = (float) x;
|
||||
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
|
||||
return ((T) 0.5) * x * (((T) 1.0) + t);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
|
||||
}
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
|
||||
}
|
||||
|
@ -1,21 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
void single_query_cached_kv_attention(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"single_query_cached_kv_attention",
|
||||
&single_query_cached_kv_attention,
|
||||
"Compute the attention between an input query and the cached key/value tensors");
|
||||
}
|
@ -4,3 +4,4 @@
|
||||
#include "dtype_float16.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
#include "dtype_bfloat16.cuh"
|
||||
#include "dtype_fp8_e5m2.cuh"
|
||||
|
@ -15,17 +15,30 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
namespace vllm {
|
||||
|
||||
@ -39,7 +52,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Compute the sum per warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Warp leaders store the data to shared memory.
|
||||
@ -58,32 +71,66 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Parallel reduction inside the warp.
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Broadcast to other threads.
|
||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||
return VLLM_SHFL_SYNC(sum, 0);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs).
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS>
|
||||
__global__ void single_query_cached_kv_attention_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int PARTITION_SIZE = 0> // Zero means no partitioning.
|
||||
__device__ void paged_attention_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride) {
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
const int max_num_partitions = gridDim.z;
|
||||
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
|
||||
// No work to do. Terminate the thread block.
|
||||
return;
|
||||
}
|
||||
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
|
||||
|
||||
// [start_block_idx, end_block_idx) is the range of blocks to process.
|
||||
const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
|
||||
const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
|
||||
const int num_blocks = end_block_idx - start_block_idx;
|
||||
|
||||
// [start_token_idx, end_token_idx) is the range of tokens to process.
|
||||
const int start_token_idx = start_block_idx * BLOCK_SIZE;
|
||||
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
|
||||
const int num_tokens = end_token_idx - start_token_idx;
|
||||
|
||||
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
||||
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
|
||||
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
|
||||
constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int thread_idx = threadIdx.x;
|
||||
const int warp_idx = thread_idx / WARP_SIZE;
|
||||
@ -91,7 +138,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
@ -102,6 +150,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
|
||||
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
|
||||
#endif
|
||||
|
||||
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
|
||||
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
|
||||
@ -116,12 +167,13 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// th vectors of the query, and so on.
|
||||
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
||||
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
||||
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
||||
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
|
||||
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
|
||||
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
|
||||
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
|
||||
}
|
||||
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
|
||||
|
||||
// Memory planning.
|
||||
extern __shared__ char shared_mem[];
|
||||
@ -132,19 +184,19 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// x == THREAD_GROUP_SIZE * VEC_SIZE
|
||||
// Each thread group fetches x elements from the key at a time.
|
||||
constexpr int x = 16 / sizeof(scalar_t);
|
||||
constexpr int x = 16 / sizeof(cache_t);
|
||||
float qk_max = -FLT_MAX;
|
||||
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
|
||||
// Iterate over the key blocks.
|
||||
// Each warp fetches a block of keys for each iteration.
|
||||
// Each thread group in a warp fetches a key from the block, and computes
|
||||
// dot product with the query.
|
||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||
const int physical_block_number = block_table[block_idx];
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
|
||||
// Load a key to registers.
|
||||
// Each thread in a thread group has a different part of the key.
|
||||
@ -158,26 +210,36 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
||||
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
||||
+ head_idx * HEAD_SIZE * BLOCK_SIZE
|
||||
+ physical_block_offset * x;
|
||||
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride
|
||||
+ physical_block_offset * x;
|
||||
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
||||
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
||||
const int offset2 = (vec_idx * VEC_SIZE) % x;
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
// Vector conversion from Quant_vec to K_vec.
|
||||
k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute dot product.
|
||||
// This includes a reduction across the threads in the same thread group.
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
|
||||
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
|
||||
// Add the ALiBi bias if slopes are given.
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
|
||||
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
|
||||
|
||||
if (thread_group_offset == 0) {
|
||||
// Store the partial reductions to shared memory.
|
||||
// NOTE(woosuk): It is required to zero out the masked logits.
|
||||
const bool mask = token_idx >= context_len;
|
||||
logits[token_idx] = mask ? 0.f : qk;
|
||||
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
|
||||
// Update the max value.
|
||||
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
|
||||
}
|
||||
@ -189,7 +251,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
// The 0-th thread of each thread group already has its max qk value.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
@ -201,14 +263,14 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
// Broadcast the max qk value to all threads.
|
||||
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
|
||||
// Get the sum of the exp values.
|
||||
float exp_sum = 0.f;
|
||||
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
float val = __expf(logits[i] - qk_max);
|
||||
logits[i] = val;
|
||||
exp_sum += val;
|
||||
@ -217,20 +279,35 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// Compute softmax.
|
||||
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
|
||||
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
|
||||
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
|
||||
logits[i] *= inv_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// If partitioning is enabled, store the max logit and exp_sum.
|
||||
if (USE_PARTITIONING && thread_idx == 0) {
|
||||
float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
*max_logits_ptr = qk_max;
|
||||
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions
|
||||
+ partition_idx;
|
||||
*exp_sums_ptr = exp_sum;
|
||||
}
|
||||
|
||||
// Each thread will fetch 16 bytes from the value cache at a time.
|
||||
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
|
||||
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
|
||||
#endif
|
||||
using Float_L_vec = typename FloatVec<L_vec>::Type;
|
||||
|
||||
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
|
||||
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
||||
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
|
||||
constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
|
||||
|
||||
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
||||
float accs[NUM_ROWS_PER_THREAD];
|
||||
@ -239,21 +316,47 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
accs[i] = 0.f;
|
||||
}
|
||||
|
||||
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
|
||||
const int physical_block_number = block_table[block_idx];
|
||||
scalar_t zero_value;
|
||||
zero(zero_value);
|
||||
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
|
||||
// NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
|
||||
// because int32 can lead to overflow when this variable is multiplied by large numbers
|
||||
// (e.g., kv_block_stride).
|
||||
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
|
||||
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
|
||||
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
|
||||
L_vec logits_vec;
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
|
||||
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
|
||||
|
||||
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
||||
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
|
||||
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
||||
+ kv_head_idx * kv_head_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
if (row_idx < HEAD_SIZE) {
|
||||
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
|
||||
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
V_vec v_vec;
|
||||
if constexpr (IS_FP8_E5M2_KV_CACHE) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
|
||||
}
|
||||
if (block_idx == num_context_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
|
||||
// we should explicitly zero out the values since they may contain NaNs.
|
||||
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
|
||||
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < V_VEC_SIZE; j++) {
|
||||
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
|
||||
}
|
||||
}
|
||||
accs[i] += dot(logits_vec, v_vec);
|
||||
}
|
||||
}
|
||||
@ -265,7 +368,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
float acc = accs[i];
|
||||
#pragma unroll
|
||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
||||
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||
}
|
||||
accs[i] = acc;
|
||||
}
|
||||
@ -308,7 +411,9 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
|
||||
// Write the final output.
|
||||
if (warp_idx == 0) {
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE
|
||||
+ partition_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
||||
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
||||
@ -319,32 +424,199 @@ __global__ void single_query_cached_kv_attention_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, 1).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE>
|
||||
__global__ void paged_attention_v1_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs, max_num_partitions).
|
||||
template<
|
||||
typename scalar_t,
|
||||
typename cache_t,
|
||||
int HEAD_SIZE,
|
||||
int BLOCK_SIZE,
|
||||
int NUM_THREADS,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_kernel(
|
||||
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride,
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
|
||||
// Grid: (num_heads, num_seqs).
|
||||
template<
|
||||
typename scalar_t,
|
||||
int HEAD_SIZE,
|
||||
int NUM_THREADS,
|
||||
int PARTITION_SIZE>
|
||||
__global__ void paged_attention_v2_reduce_kernel(
|
||||
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
||||
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int max_num_partitions) {
|
||||
const int num_heads = gridDim.x;
|
||||
const int head_idx = blockIdx.x;
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||
if (num_partitions == 1) {
|
||||
// No need to reduce. Only copy tmp_out to out.
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
||||
out_ptr[i] = tmp_out_ptr[i];
|
||||
}
|
||||
// Terminate the thread block.
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
const int warp_idx = threadIdx.x / WARP_SIZE;
|
||||
const int lane = threadIdx.x % WARP_SIZE;
|
||||
|
||||
// Size: 2 * num_partitions.
|
||||
extern __shared__ char shared_mem[];
|
||||
// Workspace for reduction.
|
||||
__shared__ float red_smem[2 * NUM_WARPS];
|
||||
|
||||
// Load max logits to shared memory.
|
||||
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
|
||||
const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions;
|
||||
float max_logit = -FLT_MAX;
|
||||
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||
const float l = max_logits_ptr[i];
|
||||
shared_max_logits[i] = l;
|
||||
max_logit = fmaxf(max_logit, l);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Get the global max logit.
|
||||
// Reduce within the warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = max_logit;
|
||||
}
|
||||
__syncthreads();
|
||||
// Reduce across warps.
|
||||
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
// Broadcast the max value to all threads.
|
||||
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||
|
||||
// Load rescaled exp sums to shared memory.
|
||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
|
||||
+ head_idx * max_num_partitions;
|
||||
float global_exp_sum = 0.0f;
|
||||
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
|
||||
float l = shared_max_logits[i];
|
||||
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
|
||||
global_exp_sum += rescaled_exp_sum;
|
||||
shared_exp_sums[i] = rescaled_exp_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
|
||||
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
|
||||
|
||||
// Aggregate tmp_out to out.
|
||||
const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
||||
+ head_idx * max_num_partitions * HEAD_SIZE;
|
||||
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
||||
#pragma unroll
|
||||
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
|
||||
float acc = 0.0f;
|
||||
for (int j = 0; j < num_partitions; ++j) {
|
||||
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
|
||||
}
|
||||
from_float(out_ptr[i], acc);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
|
||||
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, \
|
||||
query_stride);
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride);
|
||||
|
||||
// TODO(woosuk): Tune NUM_THREADS.
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int NUM_THREADS = 128>
|
||||
void single_query_cached_kv_attention_launcher(
|
||||
void paged_attention_v1_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -354,7 +626,9 @@ void single_query_cached_kv_attention_launcher(
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int query_stride = query.stride(0);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
@ -366,126 +640,314 @@ void single_query_cached_kv_attention_launcher(
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
|
||||
int logits_size = padded_max_context_len * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
|
||||
// Keep that in sync with the logic here!
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
|
||||
dim3 grid(num_heads, num_seqs);
|
||||
dim3 grid(num_heads, num_seqs, 1);
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
|
||||
// 32, 160, 192, 256.
|
||||
// case 32:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
// head sizes that we use in the model. However, we can easily extend this
|
||||
// to support any head size which is a multiple of 16.
|
||||
case 64:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
|
||||
LAUNCH_PAGED_ATTENTION_V1(64);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
|
||||
LAUNCH_PAGED_ATTENTION_V1(80);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
|
||||
LAUNCH_PAGED_ATTENTION_V1(96);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
|
||||
LAUNCH_PAGED_ATTENTION_V1(112);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
|
||||
LAUNCH_PAGED_ATTENTION_V1(128);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_PAGED_ATTENTION_V1(256);
|
||||
break;
|
||||
// case 160:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
// case 192:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
// case 256:
|
||||
// LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
||||
// break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
|
||||
single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
||||
switch (block_size) { \
|
||||
/* case 1: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 1); */ \
|
||||
/* break; */ \
|
||||
/* case 2: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 2); */ \
|
||||
/* break; */ \
|
||||
/* case 4: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 4); */ \
|
||||
/* break; */ \
|
||||
case 8: \
|
||||
CALL_KERNEL_LAUNCHER(T, 8); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_KERNEL_LAUNCHER(T, 16); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_KERNEL_LAUNCHER(T, 32); \
|
||||
break; \
|
||||
/* case 64: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 64); */ \
|
||||
/* break; */ \
|
||||
/* case 128: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 128); */ \
|
||||
/* break; */ \
|
||||
/* case 256: */ \
|
||||
/* CALL_KERNEL_LAUNCHER(T, 256); */ \
|
||||
/* break; */ \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
void single_query_cached_kv_attention(
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
|
||||
vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
|
||||
IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
tmp_out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, \
|
||||
q_stride, \
|
||||
kv_block_stride, \
|
||||
kv_head_stride); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
|
||||
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
exp_sums_ptr, \
|
||||
max_logits_ptr, \
|
||||
tmp_out_ptr, \
|
||||
context_lens_ptr, \
|
||||
max_num_partitions);
|
||||
|
||||
template<
|
||||
typename T,
|
||||
typename CACHE_T,
|
||||
int BLOCK_SIZE,
|
||||
bool IS_FP8_E5M2_KV_CACHE,
|
||||
int NUM_THREADS = 128,
|
||||
int PARTITION_SIZE = 512>
|
||||
void paged_attention_v2_launcher(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
int max_num_blocks_per_seq = block_tables.size(1);
|
||||
int q_stride = query.stride(0);
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
const float* alibi_slopes_ptr = alibi_slopes ?
|
||||
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
|
||||
: nullptr;
|
||||
|
||||
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
|
||||
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
|
||||
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
|
||||
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
|
||||
int logits_size = PARTITION_SIZE * sizeof(float);
|
||||
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
|
||||
|
||||
// For paged attention v2 kernel.
|
||||
dim3 grid(num_heads, num_seqs, max_num_partitions);
|
||||
int shared_mem_size = std::max(logits_size, outputs_size);
|
||||
// For paged attention v2 reduce kernel.
|
||||
dim3 reduce_grid(num_heads, num_seqs);
|
||||
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
switch (head_size) {
|
||||
// NOTE(woosuk): To reduce the compilation time, we only compile for the
|
||||
// head sizes that we use in the model. However, we can easily extend this
|
||||
// to support any head size which is a multiple of 16.
|
||||
case 64:
|
||||
LAUNCH_PAGED_ATTENTION_V2(64);
|
||||
break;
|
||||
case 80:
|
||||
LAUNCH_PAGED_ATTENTION_V2(80);
|
||||
break;
|
||||
case 96:
|
||||
LAUNCH_PAGED_ATTENTION_V2(96);
|
||||
break;
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V2(112);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V2(128);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_PAGED_ATTENTION_V2(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
|
||||
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
|
||||
out, \
|
||||
exp_sums, \
|
||||
max_logits, \
|
||||
tmp_out, \
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
max_context_len, \
|
||||
alibi_slopes);
|
||||
|
||||
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
||||
// 1, 2, 4, 64, 128, 256.
|
||||
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
switch (block_size) { \
|
||||
case 8: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 16: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
case 32: \
|
||||
CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
|
||||
break; \
|
||||
}
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype) {
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (query.dtype() == at::ScalarType::Float) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::Half) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
|
||||
} else if (query.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@ -17,6 +17,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "attention_dtypes.h"
|
||||
|
||||
#include <float.h>
|
||||
@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
float qk = sum(qk_vec);
|
||||
#pragma unroll
|
||||
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
||||
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
|
||||
}
|
||||
return qk;
|
||||
}
|
||||
|
@ -21,8 +21,17 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return a + b;
|
||||
#ifndef USE_ROCM
|
||||
return a + b;
|
||||
#else
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -420,4 +433,19 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// From bfloat16 to float32.
|
||||
inline __device__ float to_float(__nv_bfloat16 u) {
|
||||
return __bfloat162float(u);
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(__nv_bfloat16& dst) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
|
||||
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -21,6 +21,10 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@ -63,21 +67,47 @@ struct FloatVec<uint4> {
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||
#ifndef USE_ROCM
|
||||
uint32_t b;
|
||||
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
||||
return b;
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u16[0] = a;
|
||||
tmp.u16[1] = a;
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ float half_to_float(uint16_t h) {
|
||||
float f;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
||||
#else
|
||||
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
|
||||
#endif
|
||||
return f;
|
||||
}
|
||||
|
||||
inline __device__ float2 half2_to_float2(uint32_t v) {
|
||||
#ifndef USE_ROCM
|
||||
uint16_t lo, hi;
|
||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
||||
return make_float2(half_to_float(lo), half_to_float(hi));
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u32 = v;
|
||||
float2 ret;
|
||||
ret.x = half_to_float(tmp.u16[0]);
|
||||
ret.y = half_to_float(tmp.u16[1]);
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ uint16_t float_to_half(float f) {
|
||||
@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
||||
#else
|
||||
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
|
||||
#endif
|
||||
return tmp.u16[0];
|
||||
}
|
||||
|
||||
@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
#endif
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
tmp.u16[0] = float_to_half(f.x);
|
||||
tmp.u16[1] = float_to_half(f.y);
|
||||
#endif
|
||||
return tmp.u32;
|
||||
}
|
||||
@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
// Vector addition.
|
||||
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
||||
template<>
|
||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
|
||||
@ -390,11 +448,6 @@ inline __device__ float sum(uint4 v) {
|
||||
return sum(c);
|
||||
}
|
||||
|
||||
// Zero-out a vector.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
// From float32 to float16.
|
||||
inline __device__ void from_float(uint16_t& dst, float src) {
|
||||
dst = float_to_half(src);
|
||||
@ -441,4 +494,9 @@ inline __device__ Float8_ to_float(uint4 u) {
|
||||
return tmp;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(uint16_t& dst) {
|
||||
dst = uint16_t(0);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -265,4 +265,9 @@ inline __device__ Float8_ to_float(Float8_ u) {
|
||||
return u;
|
||||
}
|
||||
|
||||
// Zero-out a variable.
|
||||
inline __device__ void zero(float& dst) {
|
||||
dst = 0.f;
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
35
csrc/attention/dtype_fp8_e5m2.cuh
Normal file
@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include "attention_generic.cuh"
|
||||
|
||||
#include <stdint.h>
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
// fp8 vector types for quantization of kv cache
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 1> {
|
||||
using Type = uint8_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 2> {
|
||||
using Type = uint16_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 4> {
|
||||
using Type = uint32_t;
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Vec<uint8_t, 8> {
|
||||
using Type = uint2;
|
||||
};
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
|
||||
} // namespace vllm
|
@ -1,3 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <map>
|
||||
@ -18,7 +20,8 @@ void reshape_and_cache(
|
||||
torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void gather_cached_kv(
|
||||
torch::Tensor& key,
|
||||
@ -27,21 +30,7 @@ void gather_cached_kv(
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
m.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
m.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
m.def(
|
||||
"gather_cached_kv",
|
||||
&gather_cached_kv,
|
||||
"Gather key and value from the cache into contiguous QKV tensors");
|
||||
}
|
||||
// Just for unittest
|
||||
void convert_fp8_e5m2(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache);
|
@ -1,11 +1,23 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
void swap_blocks(
|
||||
torch::Tensor& src,
|
||||
torch::Tensor& dst,
|
||||
@ -26,10 +38,11 @@ void swap_blocks(
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
void *src_ptr = src.data_ptr();
|
||||
void *dst_ptr = dst.data_ptr();
|
||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// NOTE(woosuk): This can be slow if the number of blocks is large.
|
||||
for (const auto& pair : block_mapping) {
|
||||
@ -53,26 +66,26 @@ template<typename scalar_t>
|
||||
__global__ void copy_blocks_kernel(
|
||||
int64_t* key_cache_ptrs,
|
||||
int64_t* value_cache_ptrs,
|
||||
const int* __restrict__ block_mapping,
|
||||
const int64_t* __restrict__ block_mapping,
|
||||
const int numel_per_block) {
|
||||
const int layer_idx = blockIdx.x;
|
||||
const int pair_idx = blockIdx.y;
|
||||
|
||||
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
|
||||
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
|
||||
int src_block_number = block_mapping[2 * pair_idx];
|
||||
int dst_block_number = block_mapping[2 * pair_idx + 1];
|
||||
int64_t src_block_number = block_mapping[2 * pair_idx];
|
||||
int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
|
||||
|
||||
const int src_block_offset = src_block_number * numel_per_block;
|
||||
const int dst_block_offset = dst_block_number * numel_per_block;
|
||||
const int64_t src_block_offset = src_block_number * numel_per_block;
|
||||
const int64_t dst_block_offset = dst_block_number * numel_per_block;
|
||||
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
||||
int src_offset = src_block_offset + i;
|
||||
int dst_offset = dst_block_offset + i;
|
||||
int64_t src_offset = src_block_offset + i;
|
||||
int64_t dst_offset = dst_block_offset + i;
|
||||
key_cache[dst_offset] = key_cache[src_offset];
|
||||
}
|
||||
for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
|
||||
int src_offset = src_block_offset + i;
|
||||
int dst_offset = dst_block_offset + i;
|
||||
int64_t src_offset = src_block_offset + i;
|
||||
int64_t dst_offset = dst_block_offset + i;
|
||||
value_cache[dst_offset] = value_cache[src_offset];
|
||||
}
|
||||
}
|
||||
@ -100,15 +113,15 @@ void copy_blocks(
|
||||
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
|
||||
}
|
||||
// Create block mapping array.
|
||||
std::vector<int> block_mapping_vec;
|
||||
std::vector<int64_t> block_mapping_vec;
|
||||
for (const auto& pair : block_mapping) {
|
||||
int src_block_number = pair.first;
|
||||
for (int dst_block_number : pair.second) {
|
||||
int64_t src_block_number = pair.first;
|
||||
for (int64_t dst_block_number : pair.second) {
|
||||
block_mapping_vec.push_back(src_block_number);
|
||||
block_mapping_vec.push_back(dst_block_number);
|
||||
}
|
||||
}
|
||||
int* block_mapping_array = block_mapping_vec.data();
|
||||
int64_t* block_mapping_array = block_mapping_vec.data();
|
||||
int num_pairs = block_mapping_vec.size() / 2;
|
||||
|
||||
// Move the data structures to the GPU.
|
||||
@ -118,77 +131,107 @@ void copy_blocks(
|
||||
torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
|
||||
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
|
||||
torch::Tensor block_mapping_tensor = torch::from_blob(
|
||||
block_mapping_array, {2 * num_pairs}, torch::kInt).to(cache_device);
|
||||
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
|
||||
|
||||
// Launch the kernel.
|
||||
const int numel_per_block = key_caches[0][0].numel();
|
||||
dim3 grid(num_layers, num_pairs);
|
||||
dim3 block(std::min(1024, numel_per_block));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(cache_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
|
||||
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
value_cache_ptrs_tensor.data_ptr<int64_t>(),
|
||||
block_mapping_tensor.data_ptr<int>(),
|
||||
block_mapping_tensor.data_ptr<int64_t>(),
|
||||
numel_per_block);
|
||||
}));
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t>
|
||||
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
|
||||
__global__ void reshape_and_cache_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int* __restrict__ slot_mapping, // [num_tokens]
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride,
|
||||
const int value_stride,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int block_size,
|
||||
const int x) {
|
||||
const int token_idx = blockIdx.x;
|
||||
const int slot_idx = slot_mapping[token_idx];
|
||||
const int block_idx = slot_idx / block_size;
|
||||
const int block_offset = slot_idx % block_size;
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
// Padding token that should be ignored.
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
const int n = num_heads * head_size;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const int src_key_idx = token_idx * key_stride + i;
|
||||
const int src_value_idx = token_idx * value_stride + i;
|
||||
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int x_idx = head_offset / x;
|
||||
const int x_offset = head_offset % x;
|
||||
|
||||
const int tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
|
||||
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
|
||||
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
|
||||
+ head_idx * (head_size / x) * block_size * x
|
||||
+ x_idx * block_size * x
|
||||
+ block_offset * x
|
||||
+ x_offset;
|
||||
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
|
||||
+ head_idx * head_size * block_size
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (is_fp8_e5m2_kv_cache) {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
|
||||
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
} else {
|
||||
key_cache[tgt_key_idx] = tgt_key;
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
|
||||
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), \
|
||||
key_stride, \
|
||||
value_stride, \
|
||||
num_heads, \
|
||||
head_size, \
|
||||
block_size, \
|
||||
x);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping) // [num_tokens]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype)
|
||||
{
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
@ -201,26 +244,27 @@ void reshape_and_cache(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
key.scalar_type(),
|
||||
"reshape_and_cache_kernel",
|
||||
[&] {
|
||||
vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(),
|
||||
value.data_ptr<scalar_t>(),
|
||||
key_cache.data_ptr<scalar_t>(),
|
||||
value_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int>(),
|
||||
key_stride,
|
||||
value_stride,
|
||||
num_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
x);
|
||||
});
|
||||
if (kv_cache_dtype == "auto") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, float, false);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
|
||||
}
|
||||
} else if (kv_cache_dtype == "fp8_e5m2") {
|
||||
if (key.dtype() == at::ScalarType::Float) {
|
||||
CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::Half) {
|
||||
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
|
||||
} else if (key.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -264,8 +308,8 @@ __global__ void gather_cached_kv_kernel(
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
|
||||
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
|
||||
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -330,8 +374,8 @@ __global__ void gather_cached_kv_kernel_optimized(
|
||||
src_key_indices[j] = src_key_idx;
|
||||
src_value_indices[j] = src_value_idx;
|
||||
|
||||
keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
|
||||
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
@ -363,10 +407,9 @@ void gather_cached_kv(
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
|
||||
key.scalar_type(),
|
||||
"gather_cached_kv_kernel_optimized",
|
||||
[&] {
|
||||
@ -384,3 +427,55 @@ void gather_cached_kv(
|
||||
x);
|
||||
});
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__global__ void convert_fp8_e5m2_kernel(
|
||||
const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
|
||||
#else
|
||||
assert(false);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \
|
||||
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \
|
||||
block_stride);
|
||||
|
||||
void convert_fp8_e5m2(
|
||||
torch::Tensor& src_cache,
|
||||
torch::Tensor& dst_cache)
|
||||
{
|
||||
int64_t num_blocks = src_cache.size(0);
|
||||
int64_t block_stride = src_cache.stride(0);
|
||||
|
||||
dim3 grid(num_blocks);
|
||||
dim3 block(std::min(block_stride, int64_t(512)));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (src_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, float);
|
||||
} else if (src_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
|
||||
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Float) {
|
||||
CALL_CONVERT_FP8_E5M2(float, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::Half) {
|
||||
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
|
||||
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
|
||||
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
|
||||
}
|
||||
}
|
||||
|
28
csrc/cuda_compat.h
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_LDG(arg) __ldg(arg)
|
||||
#else
|
||||
#define VLLM_LDG(arg) *(arg)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
||||
#else
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
||||
#else
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#else
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
10
csrc/cuda_utils.h
Normal file
@ -0,0 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id);
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id);
|
35
csrc/cuda_utils_kernels.cu
Normal file
@ -0,0 +1,35 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#endif
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
{
|
||||
int device, value;
|
||||
if (device_id < 0) {
|
||||
cudaGetDevice(&device);
|
||||
}
|
||||
else {
|
||||
device = device_id;
|
||||
}
|
||||
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
|
||||
return value;
|
||||
}
|
||||
|
||||
|
||||
int get_max_shared_memory_per_block_device_attribute(
|
||||
int device_id)
|
||||
{
|
||||
int attribute;
|
||||
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||
|
||||
#ifdef USE_ROCM
|
||||
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
|
||||
#else
|
||||
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
|
||||
#endif
|
||||
|
||||
return get_device_attribute(attribute, device_id);
|
||||
}
|
148
csrc/custom_all_reduce.cu
Normal file
@ -0,0 +1,148 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "custom_all_reduce.cuh"
|
||||
|
||||
// fake pointer type
|
||||
using fptr_t = uint64_t;
|
||||
static_assert(sizeof(void *) == sizeof(fptr_t));
|
||||
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink) {
|
||||
int world_size = offsets.size();
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (world_size != handles.size())
|
||||
throw std::invalid_argument(
|
||||
"handles length should equal to offsets length");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
|
||||
cudaIpcMemHandle_t ipc_handles[8];
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
|
||||
}
|
||||
return (fptr_t) new vllm::CustomAllreduce(
|
||||
reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
|
||||
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
|
||||
}
|
||||
|
||||
/**
|
||||
* Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
|
||||
* t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
|
||||
* because it allows transpose of contiguous slice (i.e. slicing the first
|
||||
* dimension). Currently, we require this because stride information is not
|
||||
* passed into the kernels and we treat input tensors as flat.
|
||||
*
|
||||
* Examples
|
||||
* A = torch.zeros(3, 3, 3)
|
||||
* 1. A: OK
|
||||
* 2. A[1:]: OK
|
||||
* 3. A.permute(2, 0, 1): OK
|
||||
* 4. A[1:].permute(2, 0, 1): OK
|
||||
* 5. A[None].expand(2, -1, -1, -1): Not OK
|
||||
* 6. A[:, 1:, 1:]: Not OK
|
||||
*/
|
||||
bool _is_weak_contiguous(torch::Tensor &t) {
|
||||
return t.is_contiguous() ||
|
||||
(t.storage().nbytes() - t.storage_offset() * t.element_size() ==
|
||||
t.numel() * t.element_size());
|
||||
}
|
||||
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink) {
|
||||
auto inp_size = inp.numel() * inp.element_size();
|
||||
// custom allreduce requires input byte size to be multiples of 16
|
||||
if (inp_size % 16 != 0) return false;
|
||||
if (!_is_weak_contiguous(inp)) return false;
|
||||
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
|
||||
// 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
|
||||
// <= 512k
|
||||
return world_size <= 4 && inp_size <= 512 * 1024;
|
||||
}
|
||||
|
||||
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
|
||||
cudaStream_t stream) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
TORCH_CHECK(_is_weak_contiguous(out));
|
||||
switch (out.scalar_type()) {
|
||||
case at::ScalarType::Float: {
|
||||
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
|
||||
reinterpret_cast<float *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
case at::ScalarType::Half: {
|
||||
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
|
||||
reinterpret_cast<half *>(out.data_ptr()),
|
||||
out.numel());
|
||||
break;
|
||||
}
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
case at::ScalarType::BFloat16: {
|
||||
fa->allreduce<nv_bfloat16>(
|
||||
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
|
||||
break;
|
||||
}
|
||||
#endif
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports float32, float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
_all_reduce(_fa, inp, out, stream);
|
||||
}
|
||||
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out) {
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = c10::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto input_size = inp.numel() * inp.element_size();
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
|
||||
"registered buffer is too small to contain the input");
|
||||
AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
|
||||
input_size, cudaMemcpyDeviceToDevice, stream));
|
||||
_all_reduce(_fa, reg_buffer, out, stream);
|
||||
}
|
||||
|
||||
void dispose(fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
delete fa;
|
||||
}
|
||||
|
||||
int meta_size() { return sizeof(vllm::Metadata); }
|
||||
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_buffer(handles, offsets, t.data_ptr());
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
return fa->get_graph_buffer_ipc_meta();
|
||||
}
|
||||
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
|
||||
fa->register_graph_buffers(handles, offsets);
|
||||
}
|
562
csrc/custom_all_reduce.cuh
Normal file
@ -0,0 +1,562 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#define CUDACHECK(cmd) \
|
||||
do { \
|
||||
cudaError_t e = cmd; \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace vllm {
|
||||
|
||||
struct Signal {
|
||||
alignas(64) union {
|
||||
uint64_t flag;
|
||||
unsigned char data[8];
|
||||
} start;
|
||||
alignas(64) union {
|
||||
uint64_t flag;
|
||||
unsigned char data[8];
|
||||
} end;
|
||||
};
|
||||
|
||||
struct Metadata {
|
||||
alignas(128) Signal sg;
|
||||
alignas(128) int counter;
|
||||
};
|
||||
static_assert(offsetof(Metadata, counter) == 128);
|
||||
static_assert(sizeof(Metadata) == 256);
|
||||
|
||||
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
|
||||
|
||||
struct RankSignals {
|
||||
volatile Signal *signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
struct __align__(alignof(T) * sz) array_t {
|
||||
T data[sz];
|
||||
using type = T;
|
||||
static constexpr int size = sz;
|
||||
};
|
||||
|
||||
// use packed type to maximize memory efficiency
|
||||
// goal: generate ld.128 and st.128 instructions
|
||||
template <typename T>
|
||||
struct packed_t {
|
||||
// the (P)acked type for load/store
|
||||
using P = array_t<T, 16 / sizeof(T)>;
|
||||
// the (A)ccumulator type for reduction
|
||||
using A = array_t<float, 16 / sizeof(T)>;
|
||||
};
|
||||
|
||||
#define DINLINE __device__ __forceinline__
|
||||
|
||||
// scalar cast functions
|
||||
DINLINE float upcast_s(half val) { return __half2float(val); }
|
||||
|
||||
template <typename T>
|
||||
DINLINE T downcast_s(float val);
|
||||
template <>
|
||||
DINLINE half downcast_s(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
// scalar add functions
|
||||
// for some reason when compiling with Pytorch, the + operator for half and
|
||||
// bfloat is disabled so we call the intrinsics directly
|
||||
DINLINE half &assign_add(half &a, half b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
DINLINE float &assign_add(float &a, float b) { return a += b; }
|
||||
|
||||
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
|
||||
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
|
||||
template <>
|
||||
DINLINE nv_bfloat16 downcast_s(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
|
||||
a = __hadd(a, b);
|
||||
return a;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
assign_add(a.data[i], b.data[i]);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, int N>
|
||||
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
|
||||
if constexpr (std::is_same<T, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
array_t<float, N> out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < N; i++) {
|
||||
out.data[i] = upcast_s(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename O>
|
||||
DINLINE O downcast(array_t<float, O::size> val) {
|
||||
if constexpr (std::is_same<typename O::type, float>::value) {
|
||||
return val;
|
||||
} else {
|
||||
O out;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < O::size; i++) {
|
||||
out.data[i] = downcast_s<typename O::type>(val.data[i]);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
// compute flag at compile time
|
||||
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
|
||||
auto m = std::numeric_limits<uint64_t>::max();
|
||||
return m >> ((8 - ngpus) * 8);
|
||||
}
|
||||
|
||||
template <int ngpus>
|
||||
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
|
||||
int rank) {
|
||||
constexpr auto FLAG = compute_flag(ngpus);
|
||||
if (blockIdx.x == 0) {
|
||||
if (threadIdx.x < ngpus)
|
||||
// simultaneously write to the corresponding byte to all other ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->start.data[rank] = 255;
|
||||
else if (threadIdx.x == 32)
|
||||
// reset
|
||||
meta->sg.end.flag = 0;
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
while (meta->sg.start.flag != FLAG)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
template <int ngpus, bool final_sync = false>
|
||||
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
|
||||
int rank) {
|
||||
constexpr auto FLAG = compute_flag(ngpus);
|
||||
__syncthreads();
|
||||
__shared__ int num;
|
||||
if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
|
||||
__syncthreads();
|
||||
|
||||
// Only the last completing block can perform the end synchronization
|
||||
// This can ensures when the final busy wait ends, all ranks must have
|
||||
// finished reading each other's buffer.
|
||||
if (num == gridDim.x - 1) {
|
||||
if (threadIdx.x == 32) {
|
||||
// reset in a different warp
|
||||
meta->counter = 0;
|
||||
meta->sg.start.flag = 0;
|
||||
} else if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding byte to all other ranks.
|
||||
// Latency = 1 p2p write
|
||||
sg.signals[threadIdx.x]->end.data[rank] = 255;
|
||||
}
|
||||
// if this is the final sync, only one block needs it
|
||||
// because kernel exit can serve as sync
|
||||
if constexpr (final_sync) {
|
||||
if (threadIdx.x == 0) {
|
||||
while (meta->sg.end.flag != FLAG)
|
||||
;
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr (!final_sync) {
|
||||
if (threadIdx.x == 0) {
|
||||
while (meta->sg.end.flag != FLAG)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename P, int ngpus, typename A>
|
||||
DINLINE P packed_reduce(const P *ptrs[], int idx) {
|
||||
A tmp = upcast(ptrs[0][idx]);
|
||||
#pragma unroll
|
||||
for (int i = 1; i < ngpus; i++) {
|
||||
packed_assign_add(tmp, upcast(ptrs[i][idx]));
|
||||
}
|
||||
return downcast<P>(tmp);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
|
||||
volatile Metadata *meta, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
// note: we don't reorder the address so the accumulation order is the same
|
||||
// for all ranks, ensuring bitwise identical results
|
||||
auto dp = *_dp;
|
||||
start_sync<ngpus>(sg, meta, rank);
|
||||
// do the actual reduction
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
((P *)result)[idx] =
|
||||
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
|
||||
}
|
||||
end_sync<ngpus, true>(sg, meta, rank);
|
||||
}
|
||||
|
||||
template <typename P>
|
||||
DINLINE P *get_tmp_buf(volatile Signal *sg) {
|
||||
return (P *)(((Metadata *)sg) + 1);
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
|
||||
volatile Metadata *meta, T *__restrict__ result,
|
||||
int rank, int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
int part = size / ngpus;
|
||||
int start = rank * part;
|
||||
int end = rank == ngpus - 1 ? size : start + part;
|
||||
const P *ptrs[ngpus];
|
||||
P *tmps[ngpus];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int target = (rank + i) % ngpus;
|
||||
ptrs[i] = (const P *)_dp->ptrs[target];
|
||||
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
|
||||
}
|
||||
auto tmp_out = tmps[0];
|
||||
start_sync<ngpus>(sg, meta, rank);
|
||||
// stage 1: reduce scatter
|
||||
for (int idx = start + tid; idx < end; idx += stride) {
|
||||
tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
|
||||
}
|
||||
// Maybe TODO: replace this with per-block release-acquire
|
||||
// can save about 1-2us (not a lot though)
|
||||
end_sync<ngpus>(sg, meta, rank);
|
||||
|
||||
// stage 2: allgather
|
||||
for (int idx = tid; idx < part; idx += stride) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ngpus; i++) {
|
||||
int dst_idx = ((rank + i) % ngpus) * part + idx;
|
||||
((P *)result)[dst_idx] = tmps[i][idx];
|
||||
}
|
||||
}
|
||||
// process the last larger partition
|
||||
int remaining = size - part * ngpus;
|
||||
if (tid < remaining) {
|
||||
int dst_idx = tid + part * ngpus;
|
||||
((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
|
||||
}
|
||||
|
||||
// faster than this
|
||||
// for (int idx = tid; idx < size; idx += stride) {
|
||||
// int target_rank = idx / part;
|
||||
// if (target_rank == ngpus) target_rank -= 1;
|
||||
// ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
|
||||
// }
|
||||
}
|
||||
|
||||
template <typename T, int ngpus>
|
||||
__global__ void __launch_bounds__(512, 1)
|
||||
cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
|
||||
volatile Metadata *meta,
|
||||
T *__restrict__ result, int rank,
|
||||
int size) {
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = gridDim.x * blockDim.x;
|
||||
using P = typename packed_t<T>::P;
|
||||
using A = typename packed_t<T>::A;
|
||||
auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
|
||||
constexpr int hg = ngpus / 2;
|
||||
// Actually not quite half butterfly.
|
||||
// This is an all-to-all within each group containing half of the ranks
|
||||
// followed by cross-group add. Equivalent to half butterfly when there
|
||||
// are 4 GPUs, a common case for PCIe cards like T4 and A10.
|
||||
const P *ptrs[hg];
|
||||
{
|
||||
int start = rank - rank % hg;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < hg; i++) {
|
||||
ptrs[i] = (const P *)_dp->ptrs[i + start];
|
||||
}
|
||||
}
|
||||
start_sync<ngpus>(sg, meta, rank);
|
||||
for (int idx = tid; idx < size; idx += stride) {
|
||||
tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
|
||||
}
|
||||
end_sync<ngpus>(sg, meta, rank);
|
||||
|
||||
auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
|
||||
// do the cross group reduction
|
||||
for (int idx = tid; idx < size; idx += stride) {
|
||||
auto tmp = tmp_out[idx];
|
||||
packed_assign_add(tmp, src[idx]);
|
||||
((P *)result)[idx] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
|
||||
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
|
||||
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));
|
||||
|
||||
class CustomAllreduce {
|
||||
public:
|
||||
int rank_;
|
||||
int world_size_;
|
||||
bool full_nvlink_;
|
||||
|
||||
// below are device pointers
|
||||
RankSignals sg_;
|
||||
std::unordered_map<void *, RankData *> buffers_;
|
||||
Metadata *meta_;
|
||||
|
||||
// stores the registered device pointers from all ranks
|
||||
RankData *d_rank_data_base_, *d_rank_data_end_;
|
||||
std::vector<void *> graph_unreg_buffers_;
|
||||
// a map from IPC handles to opened IPC pointers
|
||||
std::map<IPC_KEY, char *> ipc_handles_;
|
||||
|
||||
/**
|
||||
* meta is a pointer to device metadata and temporary buffer for allreduce.
|
||||
*
|
||||
* There's a total of sizeof(Metadata) of prefix before the actual data,
|
||||
* so meta + 1 points to actual temporary buffer.
|
||||
*
|
||||
* note: this class does not own any device memory. Any required buffers
|
||||
* are passed in from the constructor
|
||||
*/
|
||||
CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
|
||||
const cudaIpcMemHandle_t *handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink = true)
|
||||
: rank_(rank),
|
||||
world_size_(offsets.size()),
|
||||
full_nvlink_(full_nvlink),
|
||||
meta_(meta),
|
||||
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
|
||||
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
Metadata *rank_meta;
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(&handles[i]);
|
||||
handle += offsets[i];
|
||||
rank_meta = (Metadata *)handle;
|
||||
} else {
|
||||
rank_meta = meta_;
|
||||
}
|
||||
sg_.signals[i] = &rank_meta->sg;
|
||||
}
|
||||
}
|
||||
|
||||
char *open_ipc_handle(const void *ipc_handle) {
|
||||
auto [it, new_handle] =
|
||||
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
|
||||
if (new_handle) {
|
||||
char *ipc_ptr;
|
||||
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
|
||||
*((const cudaIpcMemHandle_t *)ipc_handle),
|
||||
cudaIpcMemLazyEnablePeerAccess));
|
||||
it->second = ipc_ptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>>
|
||||
get_graph_buffer_ipc_meta() {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
auto handle_sz = sizeof(cudaIpcMemHandle_t);
|
||||
std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
|
||||
std::vector<int64_t> offsets(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto ptr = graph_unreg_buffers_[i];
|
||||
void *base_ptr;
|
||||
// note: must share the base address of each allocation, or we get wrong
|
||||
// address
|
||||
if (cuPointerGetAttribute(&base_ptr,
|
||||
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
|
||||
(CUdeviceptr)ptr) != CUDA_SUCCESS)
|
||||
throw std::runtime_error("failed to get pointer attr");
|
||||
CUDACHECK(cudaIpcGetMemHandle(
|
||||
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
|
||||
offsets[i] = ((char *)ptr) - ((char *)base_ptr);
|
||||
}
|
||||
return std::make_pair(handles, offsets);
|
||||
}
|
||||
|
||||
void check_rank_data_capacity(size_t num = 1) {
|
||||
if (d_rank_data_base_ + num > d_rank_data_end_)
|
||||
throw std::runtime_error(
|
||||
"Rank data buffer is overflowed by " +
|
||||
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
|
||||
}
|
||||
|
||||
void register_buffer(const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, void *self) {
|
||||
check_rank_data_capacity();
|
||||
RankData data;
|
||||
for (int i = 0; i < world_size_; i++) {
|
||||
if (i != rank_) {
|
||||
char *handle = open_ipc_handle(handles[i].data());
|
||||
handle += offsets[i];
|
||||
data.ptrs[i] = handle;
|
||||
} else {
|
||||
data.ptrs[i] = self;
|
||||
}
|
||||
}
|
||||
auto d_data = d_rank_data_base_++;
|
||||
CUDACHECK(
|
||||
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
|
||||
buffers_[self] = d_data;
|
||||
}
|
||||
|
||||
// note: when registering graph buffers, we intentionally choose to not
|
||||
// deduplicate the addresses. That means if the allocator reuses some
|
||||
// addresses, they will be registered again. This is to account for the remote
|
||||
// possibility of different allocation patterns between ranks. For example,
|
||||
// rank 1 may get the same input address for the second allreduce, but rank 2
|
||||
// got a different address. IPC handles have internal reference counting
|
||||
// mechanism so overhead should be small.
|
||||
void register_graph_buffers(
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets) {
|
||||
auto num_buffers = graph_unreg_buffers_.size();
|
||||
check_rank_data_capacity(num_buffers);
|
||||
std::vector<RankData> rank_data(num_buffers);
|
||||
for (int i = 0; i < num_buffers; i++) {
|
||||
auto self_ptr = graph_unreg_buffers_[i];
|
||||
auto &rd = rank_data[i];
|
||||
for (int j = 0; j < world_size_; j++) {
|
||||
if (j != rank_) {
|
||||
char *handle =
|
||||
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
|
||||
handle += offsets[j][i];
|
||||
rd.ptrs[j] = handle;
|
||||
} else {
|
||||
rd.ptrs[j] = self_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
|
||||
sizeof(RankData) * num_buffers,
|
||||
cudaMemcpyHostToDevice));
|
||||
d_rank_data_base_ += num_buffers;
|
||||
graph_unreg_buffers_.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* This is the result after careful grid search. Using 36 blocks give the best
|
||||
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
|
||||
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
|
||||
* Not quite sure the underlying reason, but my guess is that too many SMs
|
||||
* will cause contention on NVLink bus.
|
||||
*/
|
||||
template <typename T>
|
||||
void allreduce(cudaStream_t stream, T *input, T *output, int size,
|
||||
int threads = 512, int block_limit = 36) {
|
||||
auto d = packed_t<T>::P::size;
|
||||
if (size % d != 0)
|
||||
throw std::runtime_error(
|
||||
"custom allreduce currently requires input length to be multiple "
|
||||
"of " +
|
||||
std::to_string(d));
|
||||
|
||||
RankData *ptrs;
|
||||
cudaStreamCaptureStatus status;
|
||||
CUDACHECK(cudaStreamIsCapturing(stream, &status));
|
||||
if (status == cudaStreamCaptureStatusActive) {
|
||||
ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
|
||||
graph_unreg_buffers_.push_back(input);
|
||||
} else {
|
||||
auto it = buffers_.find(input);
|
||||
if (it == buffers_.end())
|
||||
throw std::runtime_error(
|
||||
"buffer address " +
|
||||
std::to_string(reinterpret_cast<uint64_t>(input)) +
|
||||
" is not registered!");
|
||||
ptrs = it->second;
|
||||
}
|
||||
|
||||
size /= d;
|
||||
auto bytes = size * sizeof(typename packed_t<T>::P);
|
||||
int blocks = std::min(block_limit, (size + threads - 1) / threads);
|
||||
#define KL(ngpus, name) \
|
||||
name<T, ngpus> \
|
||||
<<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
|
||||
#define REDUCE_CASE(ngpus) \
|
||||
case ngpus: { \
|
||||
if (world_size_ == 2) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else if (full_nvlink_) { \
|
||||
if ((world_size_ <= 4 && bytes < 512 * 1024) || \
|
||||
(world_size_ <= 8 && bytes < 256 * 1024)) { \
|
||||
KL(ngpus, cross_device_reduce_1stage); \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_2stage); \
|
||||
} \
|
||||
} else { \
|
||||
KL(ngpus, cross_device_reduce_half_butterfly); \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
|
||||
switch (world_size_) {
|
||||
REDUCE_CASE(2)
|
||||
REDUCE_CASE(4)
|
||||
REDUCE_CASE(6)
|
||||
REDUCE_CASE(8)
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"custom allreduce only supports num gpus in (2,4,6,8). Actual num "
|
||||
"gpus = " +
|
||||
std::to_string(world_size_));
|
||||
}
|
||||
#undef REDUCE_CASE
|
||||
#undef KL
|
||||
}
|
||||
|
||||
~CustomAllreduce() {
|
||||
for (auto [_, ptr] : ipc_handles_) {
|
||||
CUDACHECK(cudaIpcCloseMemHandle(ptr));
|
||||
}
|
||||
}
|
||||
};
|
||||
/**
|
||||
* To inspect PTX/SASS, copy paste this header file to compiler explorer and add
|
||||
a template instantiation:
|
||||
* template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
|
||||
int, int, int);
|
||||
*/
|
||||
} // namespace vllm
|
284
csrc/custom_all_reduce_test.cu
Normal file
@ -0,0 +1,284 @@
|
||||
/**
|
||||
* This is a standalone test for custom allreduce.
|
||||
* To compile, make sure you have MPI and NCCL installed in your system.
|
||||
* export MPI_HOME=XXX
|
||||
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
|
||||
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
|
||||
*
|
||||
* Warning: this C++ test is not designed to be very readable and was used
|
||||
* during the rapid prototyping process.
|
||||
*
|
||||
* To run:
|
||||
* mpirun -np 8 ./custom_all_reduce_test
|
||||
*/
|
||||
#include <cuda.h>
|
||||
#include <curand_kernel.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "cuda_profiler_api.h"
|
||||
#include "custom_all_reduce.cuh"
|
||||
#include "mpi.h"
|
||||
#include "nccl.h"
|
||||
|
||||
#define MPICHECK(cmd) \
|
||||
do { \
|
||||
int e = cmd; \
|
||||
if (e != MPI_SUCCESS) { \
|
||||
printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define NCCLCHECK(cmd) \
|
||||
do { \
|
||||
ncclResult_t r = cmd; \
|
||||
if (r != ncclSuccess) { \
|
||||
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
|
||||
ncclGetErrorString(r)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
__global__ void dummy_kernel() {
|
||||
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void set_data(T *data, int size, int myRank) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
data[idx] = myRank * 0.11f;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
|
||||
double *fdata2, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
fdata1[idx] = data1[idx];
|
||||
fdata2[idx] = data2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void init_rand(curandState_t *state, int size, int nRanks) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
|
||||
int myRank, int nRanks, int size) {
|
||||
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
|
||||
idx += gridDim.x * blockDim.x) {
|
||||
double sum = 0.0;
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
|
||||
T hval = val; // downcast first
|
||||
sum += static_cast<double>(hval);
|
||||
if (i == myRank) data[idx] = hval;
|
||||
}
|
||||
ground_truth[idx] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
|
||||
int data_size) {
|
||||
T *result;
|
||||
cudaStream_t stream;
|
||||
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
|
||||
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
|
||||
CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));
|
||||
|
||||
cudaIpcMemHandle_t self_data_handle;
|
||||
cudaIpcMemHandle_t data_handles[8];
|
||||
vllm::Metadata *buffer;
|
||||
T *self_data_copy;
|
||||
/**
|
||||
* Allocate IPC buffer
|
||||
*
|
||||
* The first section is a temporary buffer for storing intermediate allreduce
|
||||
* results, if a particular algorithm requires it. The second section is for
|
||||
* the input to the allreduce. The actual API takes the input pointer as an
|
||||
* argument (that is, they can and usually should be allocated separately).
|
||||
* But since the input pointers and the temporary buffer all require IPC
|
||||
* registration, they are allocated and registered together in the test for
|
||||
* convenience.
|
||||
*/
|
||||
CUDACHECK(
|
||||
cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
|
||||
CUDACHECK(cudaMemset(buffer, 0,
|
||||
2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
|
||||
CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
|
||||
CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));
|
||||
|
||||
MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
|
||||
MPI_BYTE, MPI_COMM_WORLD));
|
||||
|
||||
void *rank_data;
|
||||
size_t rank_data_sz = 16 * 1024 * 1024;
|
||||
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
|
||||
std::vector<int64_t> offsets(nRanks, 0);
|
||||
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
|
||||
offsets, myRank);
|
||||
auto *self_data =
|
||||
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
|
||||
sizeof(vllm::Metadata) + data_size * sizeof(T));
|
||||
// hack buffer registration
|
||||
{
|
||||
std::vector<std::string> handles;
|
||||
handles.reserve(nRanks);
|
||||
for (int i = 0; i < nRanks; i++) {
|
||||
char *begin = (char *)&data_handles[i];
|
||||
char *end = (char *)&data_handles[i + 1];
|
||||
handles.emplace_back(begin, end);
|
||||
}
|
||||
std::vector<int64_t> offsets(
|
||||
nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
|
||||
fa.register_buffer(handles, offsets, self_data);
|
||||
}
|
||||
|
||||
double *ground_truth;
|
||||
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
|
||||
curandState_t *states;
|
||||
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
|
||||
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
|
||||
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
|
||||
nRanks, data_size);
|
||||
CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
cudaEvent_t start, stop;
|
||||
CUDACHECK(cudaEventCreate(&start));
|
||||
CUDACHECK(cudaEventCreate(&stop));
|
||||
|
||||
ncclDataType_t ncclDtype;
|
||||
if (std::is_same<T, half>::value) {
|
||||
ncclDtype = ncclFloat16;
|
||||
} else if (std::is_same<T, nv_bfloat16>::value) {
|
||||
ncclDtype = ncclBfloat16;
|
||||
} else {
|
||||
ncclDtype = ncclFloat;
|
||||
}
|
||||
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
constexpr int warmup_iters = 5;
|
||||
constexpr int num_iters = 25;
|
||||
// warmup
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
|
||||
stream));
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
|
||||
stream));
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(stop, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
float allreduce_ms = 0;
|
||||
cudaEventElapsedTime(&allreduce_ms, start, stop);
|
||||
|
||||
// if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
// set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);
|
||||
|
||||
dummy_kernel<<<1, 1, 0, stream>>>();
|
||||
// warm up
|
||||
for (int i = 0; i < warmup_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(start, stream));
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
|
||||
}
|
||||
CUDACHECK(cudaEventRecord(stop, stream));
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
float duration_ms = 0;
|
||||
cudaEventElapsedTime(&duration_ms, start, stop);
|
||||
if (myRank == 0)
|
||||
printf(
|
||||
"Rank %d done, nGPUs:%d, sz (kb): %d, %d, %d, my time:%.2fus, nccl "
|
||||
"time:%.2fus\n",
|
||||
myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
|
||||
duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);
|
||||
|
||||
// And wait for all the queued up work to complete
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
|
||||
ncclSum, comm, stream));
|
||||
|
||||
double *nccl_result, *my_result;
|
||||
CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
|
||||
CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));
|
||||
|
||||
convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
|
||||
my_result, data_size);
|
||||
CUDACHECK(cudaStreamSynchronize(stream));
|
||||
|
||||
for (unsigned long j = 0; j < data_size; j++) {
|
||||
auto diff = abs(nccl_result[j] - my_result[j]);
|
||||
if (diff >= 1e-2) {
|
||||
printf("Rank %d: Verification mismatch at %lld: %f != (my) %f, gt=%f\n",
|
||||
myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
long double nccl_diffs = 0.0;
|
||||
long double my_diffs = 0.0;
|
||||
for (int j = 0; j < data_size; j++) {
|
||||
nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
|
||||
my_diffs += abs(my_result[j] - ground_truth[j]);
|
||||
}
|
||||
if (myRank == 0)
|
||||
std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
|
||||
<< " me: " << my_diffs / data_size << std::endl;
|
||||
|
||||
CUDACHECK(cudaFree(result));
|
||||
CUDACHECK(cudaFree(self_data_copy));
|
||||
CUDACHECK(cudaFree(rank_data));
|
||||
CUDACHECK(cudaFree(buffer));
|
||||
CUDACHECK(cudaFree(states));
|
||||
CUDACHECK(cudaFreeHost(ground_truth));
|
||||
CUDACHECK(cudaFreeHost(nccl_result));
|
||||
CUDACHECK(cudaFreeHost(my_result));
|
||||
CUDACHECK(cudaStreamDestroy(stream));
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
int nRanks, myRank;
|
||||
MPICHECK(MPI_Init(&argc, &argv));
|
||||
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
|
||||
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
|
||||
CUDACHECK(cudaSetDevice(myRank));
|
||||
ncclUniqueId id;
|
||||
ncclComm_t comm;
|
||||
if (myRank == 0) ncclGetUniqueId(&id);
|
||||
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
|
||||
MPI_COMM_WORLD));
|
||||
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
|
||||
|
||||
cudaProfilerStart();
|
||||
// for (int threads : {256, 512}) {
|
||||
// for (int block_limit = 16; block_limit < 112; block_limit += 4) {
|
||||
// run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
|
||||
// }
|
||||
// }
|
||||
for (int sz = 512; sz <= (32 << 20); sz *= 2) {
|
||||
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
|
||||
}
|
||||
|
||||
cudaProfilerStop();
|
||||
return EXIT_SUCCESS;
|
||||
}
|
37
csrc/dispatch_utils.h
Normal file
@ -0,0 +1,37 @@
|
||||
/*
|
||||
* Adapted from
|
||||
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
|
||||
|
||||
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
|
@ -1,14 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
}
|
@ -1,6 +1,8 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dispatch_utils.h"
|
||||
#include "reduction_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
@ -8,8 +10,8 @@ namespace vllm {
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template<typename scalar_t>
|
||||
__global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [num_tokens, hidden_size]
|
||||
const scalar_t* __restrict__ input, // [num_tokens, hidden_size]
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
@ -33,22 +35,51 @@ __global__ void rms_norm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Further optimize this kernel.
|
||||
template<typename scalar_t>
|
||||
__global__ void fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float) input[blockIdx.x * hidden_size + idx];
|
||||
x += (float) residual[blockIdx.x * hidden_size + idx];
|
||||
variance += x * x;
|
||||
residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
|
||||
}
|
||||
variance = blockReduceSum<float>(variance);
|
||||
if (threadIdx.x == 0) {
|
||||
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float) residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out, // [num_tokens, hidden_size]
|
||||
torch::Tensor& input, // [num_tokens, hidden_size]
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int num_tokens = input.size(0);
|
||||
int hidden_size = input.size(1);
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"rms_norm_kernel",
|
||||
[&] {
|
||||
@ -61,3 +92,29 @@ void rms_norm(
|
||||
hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
float epsilon) {
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"fused_add_rms_norm_kernel",
|
||||
[&] {
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
residual.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
epsilon,
|
||||
num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
}
|
||||
|
7
csrc/moe/moe_ops.cpp
Normal file
@ -0,0 +1,7 @@
|
||||
#include "moe_ops.h"
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
|
||||
}
|
9
csrc/moe/moe_ops.h
Normal file
@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights,
|
||||
torch::Tensor& topk_indices,
|
||||
torch::Tensor& token_expert_indices,
|
||||
torch::Tensor& gating_output);
|
499
csrc/moe/topk_softmax_kernels.cu
Normal file
@ -0,0 +1,499 @@
|
||||
/*
|
||||
* Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
|
||||
* Copyright (c) 2024, The vLLM team.
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
|
||||
namespace vllm {
|
||||
namespace moe {
|
||||
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
|
||||
/// Aligned array type
|
||||
template <
|
||||
typename T,
|
||||
/// Number of elements in the array
|
||||
int N,
|
||||
/// Alignment requirement in bytes
|
||||
int Alignment = sizeof(T) * N
|
||||
>
|
||||
class alignas(Alignment) AlignedArray {
|
||||
float data[N];
|
||||
};
|
||||
|
||||
// ====================== Softmax things ===============================
|
||||
// We have our own implementation of softmax here so we can support transposing the output
|
||||
// in the softmax kernel when we extend this module to support expert-choice routing.
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__
|
||||
void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
|
||||
{
|
||||
using BlockReduce = cub::BlockReduce<float, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
__shared__ float normalizing_factor;
|
||||
__shared__ float float_max;
|
||||
|
||||
const int thread_row_offset = blockIdx.x * num_cols;
|
||||
|
||||
cub::Sum sum;
|
||||
float threadData(-FLT_MAX);
|
||||
|
||||
// Don't touch finished rows.
|
||||
if ((finished != nullptr) && finished[blockIdx.x])
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData = max(static_cast<float>(input[idx]), threadData);
|
||||
}
|
||||
|
||||
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
float_max = maxElem;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
threadData = 0;
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
threadData += exp((static_cast<float>(input[idx]) - float_max));
|
||||
}
|
||||
|
||||
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
normalizing_factor = 1.f / Z;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
|
||||
{
|
||||
const int idx = thread_row_offset + ii;
|
||||
const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
|
||||
output[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <int TPB>
|
||||
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
|
||||
int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, float>;
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
|
||||
__shared__ typename BlockReduce::TempStorage tmpStorage;
|
||||
|
||||
cub_kvp thread_kvp;
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
const int num_rows = gridDim.x;
|
||||
const int block_row = blockIdx.x;
|
||||
|
||||
const bool row_is_active = finished ? !finished[block_row] : true;
|
||||
const int thread_read_offset = blockIdx.x * num_experts;
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
thread_kvp.key = 0;
|
||||
thread_kvp.value = -1.f; // This is OK because inputs are probabilities
|
||||
|
||||
cub_kvp inp_kvp;
|
||||
for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
|
||||
{
|
||||
const int idx = thread_read_offset + expert;
|
||||
inp_kvp.key = expert;
|
||||
inp_kvp.value = inputs_after_softmax[idx];
|
||||
|
||||
for (int prior_k = 0; prior_k < k_idx; ++prior_k)
|
||||
{
|
||||
const int prior_winning_expert = indices[k * block_row + prior_k];
|
||||
|
||||
if (prior_winning_expert == expert)
|
||||
{
|
||||
inp_kvp = thread_kvp;
|
||||
}
|
||||
}
|
||||
|
||||
thread_kvp = arg_max(inp_kvp, thread_kvp);
|
||||
}
|
||||
|
||||
const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
// Ignore experts the node isn't responsible for with expert parallelism
|
||||
const int expert = result_kvp.key;
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
const int idx = k * block_row + k_idx;
|
||||
output[idx] = result_kvp.value;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
|
||||
assert(indices[idx] >= 0);
|
||||
source_rows[idx] = k_idx * num_rows + block_row;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== TopK softmax things ===============================
|
||||
|
||||
/*
|
||||
A Top-K gating softmax written to exploit when the number of experts in the MoE layers
|
||||
are a small power of 2. This allows us to cleanly share the rows among the threads in
|
||||
a single warp and eliminate communication between warps (so no need to use shared mem).
|
||||
|
||||
It fuses the softmax, max and argmax into a single kernel.
|
||||
|
||||
Limitations:
|
||||
1) This implementation is intended for when the number of experts is a small power of 2.
|
||||
2) This implementation assumes k is small, but will work for any k.
|
||||
*/
|
||||
|
||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
|
||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
|
||||
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
|
||||
int* source_rows, const int k, const int start_expert, const int end_expert)
|
||||
{
|
||||
// We begin by enforcing compile time assertions and setting up compile time constants.
|
||||
static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
|
||||
static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
|
||||
static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
|
||||
static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
|
||||
|
||||
// Number of bytes each thread pulls in per load
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
|
||||
static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
|
||||
static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
|
||||
|
||||
// Restrictions based on previous section.
|
||||
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
|
||||
static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
|
||||
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
|
||||
static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
|
||||
|
||||
// We have NUM_EXPERTS elements per row. We specialize for small #experts
|
||||
static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
|
||||
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
|
||||
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
|
||||
|
||||
// Restrictions for previous section.
|
||||
static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
|
||||
|
||||
// ===================== From this point, we finally start computing run-time variables. ========================
|
||||
|
||||
// Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
|
||||
// This, each block processes a chunk of rows. We start by computing the start row for each block.
|
||||
const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
|
||||
|
||||
// Now, using the base row per thread block, we compute the base row per warp.
|
||||
const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
|
||||
|
||||
// The threads in a warp are split into sub-groups that will work on a row.
|
||||
// We compute row offset for each thread sub-group
|
||||
const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
|
||||
const int thread_row = warp_base_row + thread_row_in_warp;
|
||||
|
||||
// Threads with indices out of bounds should early exit here.
|
||||
if (thread_row >= num_rows)
|
||||
{
|
||||
return;
|
||||
}
|
||||
const bool row_is_active = finished ? !finished[thread_row] : true;
|
||||
|
||||
// We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
|
||||
// row it will read.
|
||||
const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
|
||||
|
||||
// Now, we compute the group each thread belong to in order to determine the first column to start loads.
|
||||
const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
|
||||
const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
|
||||
const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
|
||||
|
||||
// Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
|
||||
// this can support all powers of 2 up to 16.
|
||||
// NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
|
||||
// We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
|
||||
using AccessType = AlignedArray<float, ELTS_PER_LDG>;
|
||||
|
||||
// Finally, we pull in the data from global mem
|
||||
float row_chunk[VPT];
|
||||
AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
|
||||
const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
|
||||
{
|
||||
row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
|
||||
}
|
||||
|
||||
// First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
|
||||
// convert to float afterwards for the exp + sum reduction.
|
||||
float thread_max = row_chunk[0];
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < VPT; ++ii)
|
||||
{
|
||||
thread_max = max(thread_max, row_chunk[ii]);
|
||||
}
|
||||
|
||||
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
|
||||
}
|
||||
|
||||
// From this point, thread max in all the threads have the max within the row.
|
||||
// Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum.
|
||||
float row_sum = 0;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = expf(row_chunk[ii] - thread_max);
|
||||
row_sum += row_chunk[ii];
|
||||
}
|
||||
|
||||
// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern.
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
|
||||
}
|
||||
|
||||
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
|
||||
// respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
|
||||
// compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
|
||||
// However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the
|
||||
// argmax after computing the softmax.
|
||||
const float reciprocal_row_sum = 1.f / row_sum;
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < VPT; ++ii)
|
||||
{
|
||||
row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
|
||||
}
|
||||
|
||||
// Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along
|
||||
// with the max index.
|
||||
int start_col = first_elt_read_by_thread;
|
||||
static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;
|
||||
|
||||
for (int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
// First, each thread does the local argmax
|
||||
float max_val = row_chunk[0];
|
||||
int expert = start_col;
|
||||
#pragma unroll
|
||||
for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
|
||||
{
|
||||
float val = row_chunk[ldg * ELTS_PER_LDG + ii];
|
||||
|
||||
// No check on the experts here since columns with the smallest index are processed first and only
|
||||
// updated if > (not >=)
|
||||
if (val > max_val)
|
||||
{
|
||||
max_val = val;
|
||||
expert = col + ii;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
|
||||
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
|
||||
// then blank out their max with -inf and the warp can run more iterations...
|
||||
#pragma unroll
|
||||
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
|
||||
{
|
||||
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
|
||||
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);
|
||||
|
||||
// We want lower indices to "win" in every thread so we break ties this way
|
||||
if (other_max > max_val || (other_max == max_val && other_expert < expert))
|
||||
{
|
||||
max_val = other_max;
|
||||
expert = other_expert;
|
||||
}
|
||||
}
|
||||
|
||||
// Write the max for this k iteration to global memory.
|
||||
if (thread_group_idx == 0)
|
||||
{
|
||||
// Add a guard to ignore experts not included by this node
|
||||
const bool node_uses_expert = expert >= start_expert && expert < end_expert;
|
||||
const bool should_process_row = row_is_active && node_uses_expert;
|
||||
|
||||
// The lead thread from each sub-group will write out the final results to global memory. (This will be a
|
||||
// single) thread per row of the input/output matrices.
|
||||
const int idx = k * thread_row + k_idx;
|
||||
output[idx] = max_val;
|
||||
indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
|
||||
source_rows[idx] = k_idx * num_rows + thread_row;
|
||||
}
|
||||
|
||||
// Finally, we clear the value in the thread with the current max if there is another iteration to run.
|
||||
if (k_idx + 1 < k)
|
||||
{
|
||||
const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
|
||||
const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;
|
||||
|
||||
// Only the thread in the group which produced the max will reset the "winning" value to -inf.
|
||||
if (thread_group_idx == thread_to_clear_in_group)
|
||||
{
|
||||
const int offset_for_expert = expert % ELTS_PER_LDG;
|
||||
// Safe to set to any negative value since row_chunk values must be between 0 and 1.
|
||||
row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail
|
||||
{
|
||||
// Constructs some constants needed to partition the work across threads at compile time.
|
||||
template <int EXPERTS, int BYTES_PER_LDG>
|
||||
struct TopkConstants
|
||||
{
|
||||
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
|
||||
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
|
||||
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
|
||||
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
|
||||
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
|
||||
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <int EXPERTS, int WARPS_PER_TB>
|
||||
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
|
||||
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
|
||||
{
|
||||
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
|
||||
|
||||
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
|
||||
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
|
||||
static constexpr int VPT = Constants::VPT;
|
||||
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
|
||||
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
|
||||
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;
|
||||
|
||||
dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
|
||||
topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
|
||||
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
|
||||
}
|
||||
|
||||
#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
|
||||
topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
|
||||
gating_output, nullptr, topk_weights, topk_indicies, \
|
||||
token_expert_indices, num_tokens, topk, 0, num_experts, \
|
||||
stream);
|
||||
|
||||
void topkGatingSoftmaxKernelLauncher(
|
||||
const float* gating_output,
|
||||
float* topk_weights,
|
||||
int* topk_indicies,
|
||||
int* token_expert_indices,
|
||||
float* softmax_workspace,
|
||||
const int num_tokens,
|
||||
const int num_experts,
|
||||
const int topk,
|
||||
cudaStream_t stream) {
|
||||
static constexpr int WARPS_PER_TB = 4;
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_SOFTMAX(2, WARPS_PER_TB);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_SOFTMAX(4, WARPS_PER_TB);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_SOFTMAX(8, WARPS_PER_TB);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_SOFTMAX(16, WARPS_PER_TB);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_SOFTMAX(32, WARPS_PER_TB);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_SOFTMAX(64, WARPS_PER_TB);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_SOFTMAX(128, WARPS_PER_TB);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_SOFTMAX(256, WARPS_PER_TB);
|
||||
break;
|
||||
default: {
|
||||
TORCH_CHECK(softmax_workspace != nullptr,
|
||||
"softmax_workspace must be provided for num_experts that are not a power of 2.");
|
||||
static constexpr int TPB = 256;
|
||||
moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
gating_output, nullptr, softmax_workspace, num_experts);
|
||||
moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
|
||||
softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
|
||||
num_experts, topk, 0, num_experts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace moe
|
||||
} // namespace vllm
|
||||
|
||||
void topk_softmax(
|
||||
torch::Tensor& topk_weights, // [num_tokens, topk]
|
||||
torch::Tensor& topk_indices, // [num_tokens, topk]
|
||||
torch::Tensor& token_expert_indices, // [num_tokens, topk]
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
const bool needs_workspace = !is_pow_2 || num_experts > 256;
|
||||
const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
|
||||
vllm::moe::topkGatingSoftmaxKernelLauncher(
|
||||
gating_output.data_ptr<float>(),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
token_expert_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
num_tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
stream);
|
||||
}
|
108
csrc/moe_align_block_size_kernels.cu
Normal file
@ -0,0 +1,108 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
const static size_t NUM_MAX_EXPERTS = 64;
|
||||
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
|
||||
|
||||
namespace vllm {
|
||||
template <typename scalar_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
|
||||
int32_t *sorted_token_ids,
|
||||
int32_t *expert_ids,
|
||||
int32_t *total_tokens_post_pad,
|
||||
int32_t num_experts,
|
||||
int32_t block_size,
|
||||
size_t numel) {
|
||||
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
|
||||
__shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[threadIdx.x + 1][i] = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* In the first step we compute token_cnts[thread_index + 1][expert_index],
|
||||
* which counts how many tokens in the token shard of thread_index are assigned
|
||||
* to expert expert_index.
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// For each expert we accumulate the token counts from the different threads.
|
||||
tokens_cnts[0][threadIdx.x] = 0;
|
||||
for (int i = 1; i <= blockDim.x; ++i) {
|
||||
tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// We accumulate the token counts of all experts in thread 0.
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
/**
|
||||
* For each expert, each thread processes the tokens of the corresponding blocks
|
||||
* and stores the corresponding expert_id for each block.
|
||||
*/
|
||||
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
|
||||
/**
|
||||
* Each thread processes a token shard, calculating the index of each token after
|
||||
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
|
||||
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
|
||||
* where * represents a padding value(preset in python).
|
||||
*/
|
||||
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
/** The cumsum[expert_id] stores the starting index of the tokens that the
|
||||
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
|
||||
* stores the indices of the tokens processed by the expert with expert_id within
|
||||
* the current thread's token shard.
|
||||
*/
|
||||
int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
|
||||
sorted_token_ids[rank_post_pad] = i;
|
||||
++tokens_cnts[threadIdx.x][expert_id];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
assert(num_experts <= NUM_MAX_EXPERTS);
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(),
|
||||
num_experts,
|
||||
block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
}
|
130
csrc/ops.h
Normal file
@ -0,0 +1,130 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void paged_attention_v2(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& exp_sums,
|
||||
torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
int block_size,
|
||||
int max_context_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype);
|
||||
|
||||
void rms_norm(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void fused_add_rms_norm(
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& residual,
|
||||
torch::Tensor& weight,
|
||||
float epsilon);
|
||||
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache,
|
||||
bool is_neox);
|
||||
|
||||
void silu_and_mul(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_new(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table);
|
||||
|
||||
torch::Tensor gptq_gemm(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama);
|
||||
|
||||
void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm);
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = uint64_t;
|
||||
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets, int rank,
|
||||
bool full_nvlink);
|
||||
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
|
||||
bool full_nvlink);
|
||||
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
|
||||
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
|
||||
torch::Tensor &out);
|
||||
void dispose(fptr_t _fa);
|
||||
int meta_size();
|
||||
void register_buffer(fptr_t _fa, torch::Tensor &t,
|
||||
const std::vector<std::string> &handles,
|
||||
const std::vector<int64_t> &offsets);
|
||||
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets);
|
||||
#endif
|
@ -1,15 +0,0 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& positions,
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key,
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"rotary_embedding_neox",
|
||||
&rotary_embedding_neox,
|
||||
"Apply GPT-NeoX style rotary embedding to query and key");
|
||||
}
|
@ -1,17 +1,53 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename scalar_t>
|
||||
__global__ void rotary_embedding_neox_kernel(
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
inline __device__ void apply_rotary_embedding(
|
||||
scalar_t* __restrict__ arr,
|
||||
const scalar_t* __restrict__ cos_ptr,
|
||||
const scalar_t* __restrict__ sin_ptr,
|
||||
int rot_offset,
|
||||
int embed_dim)
|
||||
{
|
||||
int x_index, y_index;
|
||||
scalar_t cos, sin;
|
||||
if (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = VLLM_LDG(cos_ptr + x_index);
|
||||
sin = VLLM_LDG(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
const scalar_t y = arr[y_index];
|
||||
arr[x_index] = x * cos - y * sin;
|
||||
arr[y_index] = y * cos + x * sin;
|
||||
}
|
||||
|
||||
template<typename scalar_t, bool IS_NEOX>
|
||||
__global__ void rotary_embedding_kernel(
|
||||
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
|
||||
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
|
||||
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
|
||||
const int rot_dim,
|
||||
const int stride,
|
||||
const int64_t query_stride,
|
||||
const int64_t key_stride,
|
||||
const int num_heads,
|
||||
const int num_kv_heads,
|
||||
const int head_size) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int token_idx = blockIdx.x;
|
||||
@ -19,65 +55,76 @@ __global__ void rotary_embedding_neox_kernel(
|
||||
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
const int n = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||
const scalar_t* cos_ptr = cache_ptr;
|
||||
const scalar_t* sin_ptr = cache_ptr + embed_dim;
|
||||
|
||||
const int nq = num_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int token_head = token_idx * stride + head_idx * head_size;
|
||||
|
||||
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
const int x_index = rot_offset;
|
||||
const int y_index = embed_dim + rot_offset;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
|
||||
const int out_x = token_idx * stride + head_idx * head_size + x_index;
|
||||
const int out_y = token_idx * stride + head_idx * head_size + y_index;
|
||||
|
||||
const scalar_t cos = __ldg(cache_ptr + x_index);
|
||||
const scalar_t sin = __ldg(cache_ptr + y_index);
|
||||
|
||||
const scalar_t q_x = query[token_head + x_index];
|
||||
const scalar_t q_y = query[token_head + y_index];
|
||||
query[out_x] = q_x * cos - q_y * sin;
|
||||
query[out_y] = q_y * cos + q_x * sin;
|
||||
|
||||
const scalar_t k_x = key[token_head + x_index];
|
||||
const scalar_t k_y = key[token_head + y_index];
|
||||
key[out_x] = k_x * cos - k_y * sin;
|
||||
key[out_y] = k_y * cos + k_x * sin;
|
||||
const int nk = num_kv_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
|
||||
const int head_idx = i / embed_dim;
|
||||
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
|
||||
const int rot_offset = i % embed_dim;
|
||||
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
|
||||
sin_ptr, rot_offset, embed_dim);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void rotary_embedding_neox(
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& query, // [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads * head_size]
|
||||
void rotary_embedding(
|
||||
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
|
||||
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
|
||||
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
|
||||
int head_size,
|
||||
torch::Tensor& cos_sin_cache) // [max_position, rot_dim]
|
||||
{
|
||||
int num_tokens = query.size(0);
|
||||
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
|
||||
bool is_neox) {
|
||||
int64_t num_tokens = query.numel() / query.size(-1);
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int num_heads = query.size(1) / head_size;
|
||||
int stride = query.stride(0);
|
||||
TORCH_CHECK(stride == key.stride(0));
|
||||
int num_heads = query.size(-1) / head_size;
|
||||
int num_kv_heads = key.size(-1) / head_size;
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * rot_dim / 2, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::BFloat16,
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
query.scalar_type(),
|
||||
"rotary_embedding_neox",
|
||||
"rotary_embedding",
|
||||
[&] {
|
||||
vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
stride,
|
||||
num_heads,
|
||||
head_size);
|
||||
if (is_neox) {
|
||||
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
} else {
|
||||
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
|
||||
positions.data_ptr<int64_t>(),
|
||||
query.data_ptr<scalar_t>(),
|
||||
key.data_ptr<scalar_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
rot_dim,
|
||||
query_stride,
|
||||
key_stride,
|
||||
num_heads,
|
||||
num_kv_heads,
|
||||
head_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
217
csrc/punica/LICENSE
Normal file
@ -0,0 +1,217 @@
|
||||
Contains code from https://github.com/punica-ai/punica
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
|
||||
|
||||
Apache-2.0
|
||||
* third_party/nvbench (with LLVM exception)
|
||||
* third_party/flashinfer
|
||||
|
||||
BSD-3-Clause:
|
||||
* third_party/cutlass
|
4
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
|
4
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
|
4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
|
59
csrc/punica/bgmv/bgmv_config.h
Normal file
@ -0,0 +1,59 @@
|
||||
#pragma once
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale);
|
||||
|
||||
// clang-format off
|
||||
|
||||
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1280) \
|
||||
f(in_T, out_T, W_T, narrow, 1728) \
|
||||
f(in_T, out_T, W_T, narrow, 1792) \
|
||||
f(in_T, out_T, W_T, narrow, 2048) \
|
||||
f(in_T, out_T, W_T, narrow, 2560) \
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
f(in_T, out_T, W_T, narrow, 5120) \
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
f(in_T, out_T, W_T, narrow, 8192) \
|
||||
f(in_T, out_T, W_T, narrow, 9216) \
|
||||
f(in_T, out_T, W_T, narrow, 10240) \
|
||||
f(in_T, out_T, W_T, narrow, 11008) \
|
||||
f(in_T, out_T, W_T, narrow, 12288) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
f(in_T, out_T, W_T, narrow, 32512) \
|
||||
f(in_T, out_T, W_T, narrow, 32768) \
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||
|
||||
// clang-format on
|
4
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
4
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
|
4
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
Normal file
@ -0,0 +1,4 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
|
294
csrc/punica/bgmv/bgmv_impl.cuh
Normal file
@ -0,0 +1,294 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cuda/pipeline>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "vec_dtypes.cuh"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
// nthrs = (32, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t num_pipeline_stages = 2;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
auto pipe = cuda::make_pipeline();
|
||||
|
||||
// pipeline load W/X and compute WX;
|
||||
pipe.producer_acquire();
|
||||
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
pipe.producer_commit();
|
||||
size_t copy_idx, compute_idx;
|
||||
float y = 0.f;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
|
||||
++tile_idx) {
|
||||
copy_idx = tile_idx % num_pipeline_stages;
|
||||
// pipeline stage: async copy W fragment
|
||||
pipe.producer_acquire();
|
||||
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
|
||||
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) + tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
}
|
||||
pipe.producer_commit();
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// pipeline stage: compute WX
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] = sum;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
}
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// final pipeline stage
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] =
|
||||
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
|
||||
? sum
|
||||
: 0.f;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
|
||||
// write Y;
|
||||
if (block.thread_rank() == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
|
||||
}
|
||||
}
|
||||
|
||||
// nthrs = (2, 16, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||
typename in_T, typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t tile_idx = blockIdx.x;
|
||||
|
||||
// load X;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
|
||||
|
||||
// load W;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
|
||||
block.thread_rank() * vec_size);
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
|
||||
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += g.shfl_down(sum, offset);
|
||||
}
|
||||
sum = g.shfl(sum, 0);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||
}
|
||||
}
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
constexpr size_t vec_size = 8;
|
||||
constexpr int tz = 4;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if constexpr (feat_in < feat_out) {
|
||||
static_assert(feat_in % vec_size == 0);
|
||||
constexpr int tx = feat_in / vec_size;
|
||||
|
||||
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
|
||||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
|
||||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
|
||||
|
||||
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
|
||||
constexpr int ty = 32 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
|
||||
constexpr int ty = 16 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else {
|
||||
constexpr int ty = 8 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
} else {
|
||||
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||
feat_in % (vec_size * 16) == 0 ||
|
||||
feat_in % (vec_size * 8) == 0);
|
||||
|
||||
if constexpr (feat_in % (vec_size * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
|
||||
vec_size * sizeof(W_T), tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
|
||||
constexpr int tx = 16;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
|
||||
template void bgmv_kernel<feat_in, feat_out>( \
|
||||
out_T * __restrict__ Y, const in_T *__restrict__ X, \
|
||||
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
|
||||
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
|
||||
int64_t num_layers, int64_t layer_idx, float scale);
|
||||
|
||||
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
|
||||
INST_BGMV(wide, narrow, in_T, out_T, W_T)
|
27
csrc/punica/bgmv/generator.py
Normal file
@ -0,0 +1,27 @@
|
||||
DTYPES = ["fp16", "bf16", "fp32"]
|
||||
DTYPE_MAP = {
|
||||
"fp16": "nv_half",
|
||||
"bf16": "nv_bfloat16",
|
||||
"fp32": "float",
|
||||
}
|
||||
|
||||
TEMPLATE = """
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
|
||||
""".lstrip()
|
||||
|
||||
for input_dtype in DTYPES:
|
||||
for output_dtype in DTYPES:
|
||||
for weight_dtype in DTYPES:
|
||||
if weight_dtype == "fp32":
|
||||
# FP32 weights are not supported.
|
||||
continue
|
||||
kernel_definition = TEMPLATE.format(
|
||||
input_dtype=DTYPE_MAP[input_dtype],
|
||||
output_dtype=DTYPE_MAP[output_dtype],
|
||||
weight_dtype=DTYPE_MAP[weight_dtype])
|
||||
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
|
||||
with open(filename, "w") as f:
|
||||
f.write(kernel_definition)
|
1324
csrc/punica/bgmv/vec_dtypes.cuh
Normal file
563
csrc/punica/punica_ops.cc
Normal file
@ -0,0 +1,563 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "bgmv/bgmv_config.h"
|
||||
|
||||
namespace {
|
||||
|
||||
//====== utils ======
|
||||
|
||||
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
const char *a_name, const char *b_name) {
|
||||
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
|
||||
a.dim(), " vs ", b.dim());
|
||||
for (int i = 0; i < a.dim(); ++i) {
|
||||
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
|
||||
".size(", i, ")");
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
return (uint32_t(a) << 16) | uint32_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) \
|
||||
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) \
|
||||
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
//====== bgmv ======
|
||||
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint16_t in_features, uint16_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
switch (pack_u16(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u16(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
break;
|
||||
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
|
||||
#undef CASE
|
||||
#undef CASE_ONESIDE
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, float scale) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t h_in = x.size(1);
|
||||
int64_t h_out = y.size(1);
|
||||
int64_t num_layers = w.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
float scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t num_layers = w.size(1);
|
||||
int64_t full_y_size = y.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
//====== pybind ======
|
||||
|
||||
#define DEFINE_pybind(name) m.def(#name, &name, #name);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
|
||||
m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
|
||||
"dispatch_bgmv_low_level");
|
||||
}
|
115
csrc/pybind.cpp
Normal file
@ -0,0 +1,115 @@
|
||||
#include "cache.h"
|
||||
#include "cuda_utils.h"
|
||||
#include "ops.h"
|
||||
#include <torch/extension.h>
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// vLLM custom ops
|
||||
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
|
||||
|
||||
// Attention ops
|
||||
ops.def(
|
||||
"paged_attention_v1",
|
||||
&paged_attention_v1,
|
||||
"Compute the attention between an input query and the cached keys/values using PagedAttention.");
|
||||
ops.def(
|
||||
"paged_attention_v2",
|
||||
&paged_attention_v2,
|
||||
"PagedAttention V2.");
|
||||
|
||||
// Activation ops
|
||||
ops.def(
|
||||
"silu_and_mul",
|
||||
&silu_and_mul,
|
||||
"Activation function used in SwiGLU.");
|
||||
ops.def(
|
||||
"gelu_new",
|
||||
&gelu_new,
|
||||
"GELU implementation used in GPT-2.");
|
||||
ops.def(
|
||||
"gelu_fast",
|
||||
&gelu_fast,
|
||||
"Approximate GELU implementation.");
|
||||
|
||||
// Layernorm
|
||||
ops.def(
|
||||
"rms_norm",
|
||||
&rms_norm,
|
||||
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||
|
||||
ops.def(
|
||||
"fused_add_rms_norm",
|
||||
&fused_add_rms_norm,
|
||||
"In-place fused Add and RMS Normalization");
|
||||
|
||||
// Rotary embedding
|
||||
ops.def(
|
||||
"rotary_embedding",
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
// Quantization ops
|
||||
#ifndef USE_ROCM
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
|
||||
#endif
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
ops.def(
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
cache_ops.def(
|
||||
"swap_blocks",
|
||||
&swap_blocks,
|
||||
"Swap in (out) the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"copy_blocks",
|
||||
©_blocks,
|
||||
"Copy the cache blocks from src to dst");
|
||||
cache_ops.def(
|
||||
"reshape_and_cache",
|
||||
&reshape_and_cache,
|
||||
"Reshape the key and value tensors and cache them");
|
||||
cache_ops.def(
|
||||
"gather_cached_kv",
|
||||
&gather_cached_kv,
|
||||
"Gather key and value from the cache into contiguous QKV tensors");
|
||||
cache_ops.def(
|
||||
"convert_fp8_e5m2",
|
||||
&convert_fp8_e5m2,
|
||||
"Convert the key and value cache to fp8_e5m2 data type");
|
||||
|
||||
// Cuda utils
|
||||
pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
|
||||
cuda_utils.def(
|
||||
"get_device_attribute",
|
||||
&get_device_attribute,
|
||||
"Gets the specified device attribute.");
|
||||
|
||||
cuda_utils.def(
|
||||
"get_max_shared_memory_per_block_device_attribute",
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Custom all-reduce kernels
|
||||
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
||||
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
||||
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
||||
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
||||
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
||||
custom_ar.def("dispose", &dispose, "dispose");
|
||||
custom_ar.def("meta_size", &meta_size, "meta_size");
|
||||
custom_ar.def("register_buffer", ®ister_buffer, "register_buffer");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
||||
"get_graph_buffer_ipc_meta");
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
87
csrc/quantization/awq/dequantize.cuh
Normal file
@ -0,0 +1,87 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
uint4 result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||
|
||||
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||
// immediately before required.
|
||||
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||
// This is the half2 {-72, -72} represented as an integer.
|
||||
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||
// Haotian: Let's use {-64, -64}.
|
||||
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
446
csrc/quantization/awq/gemm_kernels.cu
Normal file
@ -0,0 +1,446 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dequantize.cuh"
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned
|
||||
__pack_half2(const half x, const half y) {
|
||||
unsigned v0 = *((unsigned short *)&x);
|
||||
unsigned v1 = *((unsigned short *)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
template<int N>
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
|
||||
int G,
|
||||
int split_k_iters,
|
||||
half* __restrict__ A,
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
int M,
|
||||
int IC,
|
||||
int OC,
|
||||
half* __restrict__ C)
|
||||
{
|
||||
// Only support matrix n = 64 or 128
|
||||
assert(N == 64 || N == 128);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
|
||||
assert(false);
|
||||
#else
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (N + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[N];
|
||||
__shared__ half zeros_shared[N];
|
||||
|
||||
int j_factors1 = ((OC + N - 1) / N);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[N / 4];
|
||||
for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * (256 / N)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
|
||||
+ (((int)threadIdx.x) / (N / 8)) * (N + 8)
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (N / 8)
|
||||
+ ((int)threadIdx.x) % (N / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ (((int)threadIdx.x) % (N / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ static_cast<long long>(blockIdx_z) * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * N
|
||||
+ ((int)threadIdx.y) * (N / 2)
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
#else
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) dequantize_weights(
|
||||
int* __restrict__ B,
|
||||
half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros,
|
||||
half* __restrict__ C,
|
||||
int G
|
||||
)
|
||||
{
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
half B_shared[32 * (128 + 8)];
|
||||
|
||||
half* B_shared_ptr2 = B_shared;
|
||||
|
||||
half B_shared_warp[32];
|
||||
int OC = 512;
|
||||
|
||||
int N = blockDim.x * gridDim.x; // 2
|
||||
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
int index1 = 8 * col + 8 * row * N;
|
||||
half* C_ptr2 = C + index1;
|
||||
|
||||
int index2 = col + row * N;
|
||||
int* B_ptr2 = B + index2;
|
||||
|
||||
int index3 = col + (int)(row / G) * N;
|
||||
int* zeros_ptr2 = zeros + index3;
|
||||
int index4 = 8 * col + (int)(row / G) * N * 8;
|
||||
half* scaling_factors_ptr2 = scaling_factors + index4;
|
||||
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
|
||||
|
||||
uint32_t B_loaded = *(uint32_t*)B_ptr2;
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
|
||||
*(uint4*)B_shared_ptr2 = B_loaded_fp16;
|
||||
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
*(C_ptr2 + i) = B_shared[i];
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor awq_dequantize(
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters,
|
||||
int thx,
|
||||
int thy)
|
||||
{
|
||||
int in_c = _kernel.size(0);
|
||||
int qout_c = _kernel.size(1);
|
||||
int out_c = qout_c * 8;
|
||||
int G = in_c / _scaling_factors.size(0);
|
||||
|
||||
int x_thread = thx;
|
||||
int y_thread = thy;
|
||||
|
||||
int x_blocks = 1;
|
||||
int y_blocks = 1;
|
||||
if (thx==0) {
|
||||
x_thread = qout_c;
|
||||
}
|
||||
if (thy==0) {
|
||||
y_thread = in_c;
|
||||
}
|
||||
if (thx==0 && thy==0) {
|
||||
x_thread = 8;
|
||||
y_thread = 8;
|
||||
x_blocks = (int)(qout_c / 8);
|
||||
y_blocks = (int)(in_c / 8);
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
|
||||
at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
|
||||
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
|
||||
dim3 num_blocks(x_blocks, y_blocks);
|
||||
dim3 threads_per_block(x_thread, y_thread);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
kernel, scaling_factors, zeros, de_kernel, G);
|
||||
|
||||
return _de_kernel;
|
||||
}
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||
// assume that batch_size < 16 for now
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters)
|
||||
{
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
int num_out_feats = _out_feats.size(-2);
|
||||
int num_out_channels = _out_feats.size(-1);
|
||||
|
||||
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||
|
||||
if (num_out_channels % 64 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||
if (num_out_channels % 8 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||
if (group_size % 32 != 0)
|
||||
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||
if (num_out_channels % group_size != 0)
|
||||
throw std::invalid_argument("OC is not multiple of Group size");
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
if (num_out_channels % 128 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 128 / 1;
|
||||
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 64 / 1;
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
|
||||
num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
|
277
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh
Normal file
@ -0,0 +1,277 @@
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <float.h>
|
||||
#include <type_traits>
|
||||
#include "../../attention/attention_dtypes.h"
|
||||
#include "../../attention/dtype_float32.cuh"
|
||||
#include "../../attention/dtype_float16.cuh"
|
||||
#include "../../attention/dtype_bfloat16.cuh"
|
||||
|
||||
|
||||
namespace vllm {
|
||||
#ifdef ENABLE_FP8_E5M2
|
||||
namespace fp8_e5m2_unscaled {
|
||||
|
||||
template<typename Tout, typename Tin>
|
||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
// fp8 -> half
|
||||
template<>
|
||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
return res.x;
|
||||
}
|
||||
|
||||
// fp8x2 -> half2
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
union {
|
||||
uint16_t u16[2];
|
||||
uint32_t u32;
|
||||
} tmp;
|
||||
__half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
|
||||
tmp.u16[0] = res.x;
|
||||
tmp.u16[1] = res.y;
|
||||
return tmp.u32;
|
||||
}
|
||||
|
||||
// fp8x4 -> half2x2
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
union {
|
||||
uint2 u32x2;
|
||||
uint32_t u32[2];
|
||||
} tmp;
|
||||
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
|
||||
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
|
||||
return tmp.u32x2;
|
||||
}
|
||||
|
||||
// fp8x8 -> half2x4
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
|
||||
{
|
||||
union {
|
||||
uint4 u64x2;
|
||||
uint2 u64[2];
|
||||
} tmp;
|
||||
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
|
||||
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
|
||||
return tmp.u64x2;
|
||||
}
|
||||
|
||||
// fp8 -> __nv_bfloat16
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// Note there is no direct convert function from fp8 to bf16.
|
||||
// fp8 -> half
|
||||
__half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
|
||||
// half -> float -> bf16
|
||||
float tmp = half_to_float(res.x);
|
||||
return __float2bfloat16(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> __nv_bfloat162
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__nv_bfloat162 res;
|
||||
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x4 -> bf16_4_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
bf16_4_t res;
|
||||
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> bf16_8_t
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
|
||||
{
|
||||
bf16_4_t tmp1, tmp2;
|
||||
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
|
||||
bf16_8_t res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8 -> float
|
||||
template<>
|
||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
|
||||
{
|
||||
// fp8 -> half
|
||||
uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
|
||||
// half -> float
|
||||
return half_to_float(tmp);
|
||||
}
|
||||
|
||||
// fp8x2 -> float2
|
||||
template<>
|
||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
// fp8x2 -> half2
|
||||
uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
|
||||
// half2 -> float2
|
||||
return half2_to_float2(tmp);
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ res;
|
||||
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
|
||||
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
|
||||
return res;
|
||||
}
|
||||
|
||||
// fp8x8 -> float8
|
||||
template<>
|
||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
|
||||
{
|
||||
Float4_ tmp1, tmp2;
|
||||
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
|
||||
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
|
||||
Float8_ res;
|
||||
res.x = tmp1.x;
|
||||
res.y = tmp1.y;
|
||||
res.z = tmp2.x;
|
||||
res.w = tmp2.y;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
// half -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
|
||||
{
|
||||
__half_raw tmp;
|
||||
tmp.x = a;
|
||||
__nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// bf16 -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
__nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
template<>
|
||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
|
||||
{
|
||||
__nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
|
||||
return (uint8_t)res;
|
||||
}
|
||||
|
||||
// fp8x4 -> float4
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
|
||||
{
|
||||
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
|
||||
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
|
||||
{
|
||||
union {
|
||||
half2 float16;
|
||||
uint32_t uint32;
|
||||
};
|
||||
|
||||
float16 = __float22half2_rn(a);
|
||||
return uint32;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
|
||||
{
|
||||
uint2 b;
|
||||
float2 val;
|
||||
val.x = a.x.x;
|
||||
val.y = a.x.y;
|
||||
b.x = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
val.x = a.y.x;
|
||||
val.y = a.y.y;
|
||||
b.y = vec_conversion<uint32_t, float2>(val);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
|
||||
{
|
||||
float4 b;
|
||||
b.x = a.x.x;
|
||||
b.y = a.x.y;
|
||||
b.z = a.y.x;
|
||||
b.w = a.y.y;
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
|
||||
{
|
||||
uint4 b;
|
||||
b.x = vec_conversion<uint32_t, float2>(a.x);
|
||||
b.y = vec_conversion<uint32_t, float2>(a.y);
|
||||
b.z = vec_conversion<uint32_t, float2>(a.z);
|
||||
b.w = vec_conversion<uint32_t, float2>(a.w);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
|
||||
__nv_bfloat162 b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
|
||||
bf16_4_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
template<>
|
||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
|
||||
bf16_8_t b;
|
||||
from_float(b, a);
|
||||
return b;
|
||||
}
|
||||
|
||||
} // namespace fp8_e5m2_unscaled
|
||||
#endif // ENABLE_FP8_E5M2
|
||||
} // namespace vllm
|
64
csrc/quantization/gptq/compat.cuh
Normal file
@ -0,0 +1,64 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
151
csrc/quantization/gptq/matrix_view.cuh
Normal file
@ -0,0 +1,151 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
|
||||
*/
|
||||
|
||||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_column
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (row & 0x07) * 4;
|
||||
return (data[row / 8 * width + column] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
|
||||
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
|
||||
};
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
875
csrc/quantization/gptq/q_gemm.cu
Normal file
@ -0,0 +1,875 @@
|
||||
/*
|
||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "compat.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "qdq_4.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
#define BLOCK_M_SIZE_MAX 8
|
||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||
#define MAX_Q_GEMM_ROWS 50
|
||||
#define MAX_ALT_GEMM_ROWS 8
|
||||
#define THREADS_X 32
|
||||
#define THREADS_Y 32
|
||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
#include <hipblas/hipblas.h>
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const half* alpha,
|
||||
const half* AP,
|
||||
int lda,
|
||||
const half* BP,
|
||||
int ldb,
|
||||
const half* beta,
|
||||
half* CP,
|
||||
int ldc) {
|
||||
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||
reinterpret_cast<const hipblasHalf *>(beta),
|
||||
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||
}
|
||||
#define hipblasHgemm __compat_hipblasHgemm
|
||||
|
||||
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
|
||||
#define rocblas_operation_none HIPBLAS_OP_N
|
||||
#define rocblas_hgemm __compat_hipblasHgemm
|
||||
#endif
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hadd2(result, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
}
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int*
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_gptq_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int* __restrict__ b_q_perm
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
|
||||
half a0;
|
||||
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
else a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Zero output
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (blockIdx.z == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
int groupsize = size_k / groups;
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// a, b offset
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
int zeros[4];
|
||||
float scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
// Column result
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize and multiply
|
||||
int k = offset_k;
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
}
|
||||
|
||||
b_ptr += size_n;
|
||||
a_ptr += 8;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
||||
{
|
||||
#if BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
const half* a,
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_q_perm,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int m_count,
|
||||
int groups
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
a,
|
||||
b_q_weight,
|
||||
b_gptq_qzeros,
|
||||
b_gptq_scales,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
groups,
|
||||
b_q_perm
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
__global__ void reconstruct_exllama_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const int* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groups,
|
||||
half* __restrict__ b
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
// Preload remapping table
|
||||
__shared__ int perm[BLOCK_KN_SIZE];
|
||||
int t = threadIdx.x;
|
||||
|
||||
if (b_q_perm)
|
||||
{
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
}
|
||||
|
||||
// Column
|
||||
int n = offset_n + t * 4;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
int groupsize = size_k / groups;
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// b offset
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
// Initial zeros/scale
|
||||
int zeros[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4][4];
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
b_ptr += size_n;
|
||||
//half* dqh = (half*)dq;
|
||||
if (b_q_perm)
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void reconstruct_exllama
|
||||
(
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_q_perm,
|
||||
half* out,
|
||||
int height,
|
||||
int width,
|
||||
int groups
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
b_q_weight,
|
||||
b_q_perm,
|
||||
b_gptq_qzeros,
|
||||
b_gptq_scales,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
out
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
__global__ void gemm_half_q_half_alt_kernel(
|
||||
const half2* __restrict__ vec,
|
||||
const uint32_t* __restrict__ mat,
|
||||
half* __restrict__ mul,
|
||||
const half* __restrict__ scales,
|
||||
const uint32_t* __restrict__ zeros,
|
||||
const int* __restrict__ g_idx,
|
||||
int batch,
|
||||
int height,
|
||||
int width
|
||||
)
|
||||
{
|
||||
int zero_width = width / 8;
|
||||
int vec_height = height * 4;
|
||||
const int blockwidth2 = BLOCK_KN_SIZE / 2;
|
||||
int b = blockIdx.y * BLOCK_M_SIZE_MAX;
|
||||
int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
|
||||
int h = BLOCK_KN_SIZE * blockIdx.z / 8;
|
||||
int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
|
||||
int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
|
||||
|
||||
__shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
|
||||
if (threadIdx.x < h_end) {
|
||||
for (int m = 0; m < b_end; ++m) {
|
||||
blockvec[m][threadIdx.x] =
|
||||
vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
|
||||
threadIdx.x];
|
||||
}
|
||||
}
|
||||
|
||||
__shared__ half2 deq2[256][8];
|
||||
int val = threadIdx.x / 8;
|
||||
int off = threadIdx.x % 8;
|
||||
for (; val < 256; val += BLOCK_KN_SIZE / 8) {
|
||||
deq2[val][off] = __halves2half2(
|
||||
__int2half_rn(val & 0xF), __int2half_rn(val >> 4)
|
||||
);
|
||||
}
|
||||
|
||||
if (blockIdx.z == 0)
|
||||
{
|
||||
for (int m = 0; m < b_end; m++)
|
||||
mul[(b + m) * width + w] = __int2half_rn(0);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
int i = width * h + w;
|
||||
int g_h = h * 8;
|
||||
int k = 0;
|
||||
int z_w = w / 8;
|
||||
int z_mod = (w % 8) * 4;
|
||||
half2 res2;
|
||||
half res[BLOCK_M_SIZE_MAX] = {};
|
||||
|
||||
unsigned int tmp;
|
||||
while (k < h_end) {
|
||||
tmp = mat[i];
|
||||
half2 scales_tmp[4];
|
||||
half2 zeros_tmp[4];
|
||||
for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
|
||||
int g = g_idx[g_h + (k + tmp_k) * 2];
|
||||
int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
|
||||
half scale_f = scales[g * width + w];
|
||||
half scale_f2 = scales[g2 * width + w];
|
||||
half2 scale = __halves2half2(scale_f, scale_f2);
|
||||
half2 zero = __halves2half2(
|
||||
__hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
|
||||
__hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
|
||||
);
|
||||
scales_tmp[tmp_k] = scale;
|
||||
zeros_tmp[tmp_k] = zero;
|
||||
}
|
||||
for (int m = 0; m < b_end; m++) {
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
|
||||
#ifndef USE_ROCM
|
||||
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
|
||||
#else
|
||||
res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
|
||||
#endif
|
||||
}
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
for (int m = 0; m < b_end; m++) {
|
||||
atomicAdd(&mul[(b + m) * width + w], res[m]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void gemm_half_q_half_alt
|
||||
(
|
||||
const half* a,
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_g_idx,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
|
||||
gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half2*) a,
|
||||
b_q_weight,
|
||||
c,
|
||||
b_gptq_scales,
|
||||
b_gptq_qzeros,
|
||||
b_g_idx,
|
||||
size_m,
|
||||
size_k / 8,
|
||||
size_n
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
__global__ void reconstruct_gptq_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
const half* __restrict__ w_scales,
|
||||
const uint32_t* __restrict__ w_zeros,
|
||||
const int* __restrict__ g_idx,
|
||||
const int height,
|
||||
const int width,
|
||||
const int group,
|
||||
half* __restrict__ out
|
||||
)
|
||||
{
|
||||
// Start of block
|
||||
|
||||
int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
|
||||
int row = blockIdx.y * 8;
|
||||
if (column >= width) return;
|
||||
|
||||
// Views
|
||||
|
||||
MatrixView_q4_column w_(w, height, width);
|
||||
MatrixView_half_rw out_(out, height, width);
|
||||
MatrixView_half w_scales_(w_scales, group, width);
|
||||
MatrixView_q4_row w_zeros_(w_zeros, group, width);
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
||||
#pragma unroll
|
||||
for (int s = 0; s < 32; s += 4)
|
||||
{
|
||||
int group = g_idx[row + s / 4];
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
|
||||
*out_ptr = w_item; out_ptr += out_.width;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void reconstruct_gptq
|
||||
(
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_g_idx,
|
||||
half* out,
|
||||
int height,
|
||||
int width,
|
||||
int groups
|
||||
)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.y = DIVIDE(height, 8);
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
b_q_weight,
|
||||
b_gptq_scales,
|
||||
b_gptq_qzeros,
|
||||
b_g_idx,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
out
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
const uint32_t* b_q_weight,
|
||||
const uint32_t* b_gptq_qzeros,
|
||||
const half* b_gptq_scales,
|
||||
const int* b_g_idx,
|
||||
half* c,
|
||||
half* temp_dq,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int groups,
|
||||
bool use_exllama
|
||||
)
|
||||
{
|
||||
if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
if (use_exllama) {
|
||||
reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
|
||||
size_k, size_n, groups);
|
||||
}
|
||||
else
|
||||
{
|
||||
reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
|
||||
temp_dq, size_k, size_n, groups);
|
||||
}
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = __float2half(0.0f);
|
||||
cublasHgemm(cublas_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
size_n, size_m, size_k,
|
||||
&alpha, temp_dq, size_n,
|
||||
a, size_k,
|
||||
&beta, c, size_n);
|
||||
}
|
||||
else if (use_exllama)
|
||||
{
|
||||
// Quantized matmul
|
||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
|
||||
c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
|
||||
groups);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
|
||||
b_gptq_scales, b_g_idx, c + last_chunk * size_n,
|
||||
last_chunk_size, size_n, size_k, last_chunk_size,
|
||||
groups);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
|
||||
c, size_m, size_n, size_k);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void shuffle_kernel
|
||||
(
|
||||
uint32_t* __restrict__ b_q_weight,
|
||||
const int size_k,
|
||||
const int size_n
|
||||
)
|
||||
{
|
||||
int n = blockIdx.x * THREADS_X + threadIdx.x;
|
||||
if (n >= size_n) return;
|
||||
int k = 0;
|
||||
uint32_t* b_ptr = b_q_weight + n;
|
||||
while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
|
||||
}
|
||||
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const int* __restrict__ q_perm,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
int w_new2_row = blockIdx.y;
|
||||
int q_perm_idx = w_new2_row << 3;
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = q_perm[q_perm_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
|
||||
void shuffle_exllama_weight
|
||||
(
|
||||
uint32_t* q_weight,
|
||||
int* q_perm,
|
||||
int height,
|
||||
int width
|
||||
)
|
||||
{
|
||||
if (q_perm)
|
||||
{
|
||||
uint32_t* new_qweight = NULL;
|
||||
cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = height / 8;
|
||||
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
q_weight,
|
||||
new_qweight,
|
||||
q_perm,
|
||||
height / 8,
|
||||
width
|
||||
);
|
||||
// Replace qweights
|
||||
cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
// Cleanup
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(new_qweight);
|
||||
}
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = 1;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
torch::Tensor gptq_gemm
|
||||
(
|
||||
torch::Tensor a,
|
||||
torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
torch::Tensor b_gptq_scales,
|
||||
torch::Tensor b_g_idx,
|
||||
bool use_exllama
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
|
||||
at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
|
||||
at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);
|
||||
|
||||
vllm::gptq::gemm_half_q_half_cuda
|
||||
(
|
||||
at::cuda::getCurrentCUDABlasHandle(),
|
||||
(const half*) a.data_ptr(),
|
||||
(const uint32_t*) b_q_weight.data_ptr(),
|
||||
(const uint32_t*)b_gptq_qzeros.data_ptr(),
|
||||
(const half*) b_gptq_scales.data_ptr(),
|
||||
b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
|
||||
(half*) c.data_ptr(),
|
||||
(half*) temp_dq.data_ptr(),
|
||||
c.size(0), // m
|
||||
c.size(1), // n
|
||||
a.size(1), // k
|
||||
b_gptq_qzeros.size(0), // group number
|
||||
use_exllama
|
||||
);
|
||||
return c;
|
||||
}
|
||||
|
||||
void gptq_shuffle
|
||||
(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||
vllm::gptq::shuffle_exllama_weight(
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
|
||||
q_weight.size(0) * 8,
|
||||
q_weight.size(1)
|
||||
);
|
||||
}
|
235
csrc/quantization/gptq/qdq_4.cuh
Normal file
@ -0,0 +1,235 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
|
||||
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z16 = __halves2half2(z16_, z16_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1z16)[2],
|
||||
half2(&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#else
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z = __hmul(z, scale);
|
||||
z1[0] = __half2half2(z);
|
||||
y1[0] = __half2half2(scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1)[2],
|
||||
half2(&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z1[0] = __half2half2(z);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
half2 dqh2[8];
|
||||
|
||||
uint32_t qa = q_0;
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
dqh2[i] = __halves2half2(d0, d1);
|
||||
}
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
|
||||
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
|
||||
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
|
||||
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(dqh2[0], z1[0]);
|
||||
dq[1] = __hadd2(dqh2[1], z1[0]);
|
||||
dq[2] = __hadd2(dqh2[2], z1[0]);
|
||||
dq[3] = __hadd2(dqh2[3], z1[0]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
|
||||
#endif
|
60
csrc/quantization/gptq/qdq_util.cuh
Normal file
@ -0,0 +1,60 @@
|
||||
/*
|
||||
Copied from https://github.com/turboderp/exllamav2
|
||||
*/
|
||||
|
||||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
namespace vllm {
|
||||
namespace gptq {
|
||||
|
||||
union half2_uint32
|
||||
{
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
};
|
||||
|
||||
union half_uint16
|
||||
{
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
__device__ half_uint16(half val) : as_half(val) {}
|
||||
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||
{
|
||||
int qs_i = qs + 1;
|
||||
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||
{
|
||||
return __hmul(__int2half_rn(q - qzero), scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||
{
|
||||
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||
return __int2half_rn(q - qzero);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||
{
|
||||
return (int)((q >> shift) & mask);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||
{
|
||||
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||
}
|
||||
|
||||
} // namespace gptq
|
||||
} // namespace vllm
|
||||
#endif
|
225
csrc/quantization/squeezellm/quant_cuda_kernel.cu
Normal file
@ -0,0 +1,225 @@
|
||||
#include <torch/all.h>
|
||||
#include <torch/python.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
// half-tensor
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <ATen/cuda/CUDATensorMethods.cuh>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#define BLOCKWIDTH 128
|
||||
#define BLOCKHEIGHT4 16
|
||||
|
||||
namespace vllm {
|
||||
namespace squeezellm {
|
||||
|
||||
__device__ inline unsigned int as_unsigned(int i) {
|
||||
return *reinterpret_cast<unsigned int*>(&i);
|
||||
}
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
const half2* __restrict__ vec,
|
||||
#else
|
||||
const __half2* __restrict__ vec,
|
||||
#endif
|
||||
const int* __restrict__ mat,
|
||||
#ifndef USE_ROCM
|
||||
half2* __restrict__ mul,
|
||||
#else
|
||||
float2* __restrict__ mul,
|
||||
#endif
|
||||
const __half* __restrict__ lookup_table,
|
||||
int height,
|
||||
int width,
|
||||
int batch,
|
||||
int vec_height
|
||||
) {
|
||||
|
||||
const int blockwidth2 = BLOCKWIDTH / 2;
|
||||
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
#else
|
||||
__shared__ __half2 blockvec[blockwidth2];
|
||||
#endif
|
||||
|
||||
__shared__ __half deq2[16][BLOCKWIDTH];
|
||||
int off = threadIdx.x;
|
||||
int column_offset = col * 16;
|
||||
for (int val = 0; val < 16; val += 1) {
|
||||
int lut_index = column_offset + val;
|
||||
deq2[val][off] = lookup_table[lut_index];
|
||||
}
|
||||
|
||||
__half res;
|
||||
#ifndef USE_ROCM
|
||||
half2 res2;
|
||||
half2 tmp2;
|
||||
#else
|
||||
__half2 res2;
|
||||
__half2 tmp2;
|
||||
#endif
|
||||
|
||||
int i;
|
||||
int k;
|
||||
|
||||
unsigned int tmp1;
|
||||
unsigned int lut_index1, lut_index2;
|
||||
|
||||
for (int b = 0; b < batch; ++b){
|
||||
i = width * row + col;
|
||||
res = __int2half_rd(0);
|
||||
k = 0;
|
||||
|
||||
__syncthreads();
|
||||
if (threadIdx.x < blockwidth2)
|
||||
blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
|
||||
__syncthreads();
|
||||
|
||||
while (k < blockwidth2) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
tmp2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
tmp2.x = __half_as_ushort(__float2half(0));
|
||||
tmp2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
|
||||
lut_index1 = tmp1 & 0xF;
|
||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
#else
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
||||
#endif
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
// col%2 -> only set one of the two values
|
||||
#ifndef USE_ROCM
|
||||
half2 res3 = {};
|
||||
if (col % 2 == 0) {
|
||||
res3.x = res;
|
||||
} else {
|
||||
res3.y = res;
|
||||
}
|
||||
#else
|
||||
__half2 res3;
|
||||
res3.x = __half_as_ushort(__float2half(0));
|
||||
res3.y = __half_as_ushort(__float2half(0));
|
||||
if (col % 2 == 0) {
|
||||
res3.x = __half_as_ushort(res);
|
||||
} else {
|
||||
res3.y = __half_as_ushort(res);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||
#else
|
||||
int tmp_addr = b * width / 2 + col / 2;
|
||||
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
||||
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace squeezellm
|
||||
} // namespace vllm
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
torch::Tensor mat,
|
||||
torch::Tensor mul,
|
||||
torch::Tensor lookup_table
|
||||
) {
|
||||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
|
||||
int batch = vec.size(0);
|
||||
int vec_height = vec.size(1);
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH
|
||||
);
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
|
||||
#ifndef USE_ROCM
|
||||
(half2*) vec.data<at::Half>(),
|
||||
#else
|
||||
(__half2*) vec.data_ptr<at::Half>(),
|
||||
#endif
|
||||
mat.data_ptr<int>(),
|
||||
#ifndef USE_ROCM
|
||||
(half2*) mul.data<at::Half>(),
|
||||
(__half*) lookup_table.data<at::Half>(),
|
||||
#else
|
||||
(float2*) mul.data_ptr<float>(),
|
||||
(__half*) lookup_table.data_ptr<at::Half>(),
|
||||
#endif
|
||||
height, width, batch, vec_height
|
||||
);
|
||||
}
|
||||
|
||||
#undef BLOCKWIDTH
|
||||
#undef BLOCKHEIGHT4
|
@ -17,13 +17,15 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cuda_compat.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
|
||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||
return val;
|
||||
}
|
||||
|
||||
|
Before Width: | Height: | Size: 267 KiB |
Before Width: | Height: | Size: 285 KiB |
Before Width: | Height: | Size: 259 KiB |
Before Width: | Height: | Size: 276 KiB |
Before Width: | Height: | Size: 244 KiB |
Before Width: | Height: | Size: 260 KiB |
Before Width: | Height: | Size: 255 KiB |
Before Width: | Height: | Size: 272 KiB |
@ -9,11 +9,15 @@
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
# import os
|
||||
# import sys
|
||||
# sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
import os
|
||||
import sys
|
||||
from sphinx.ext import autodoc
|
||||
import logging
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
@ -21,7 +25,6 @@ project = 'vLLM'
|
||||
copyright = '2023, vLLM Team'
|
||||
author = 'the vLLM Team'
|
||||
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
@ -32,6 +35,8 @@ extensions = [
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinx_copybutton",
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
@ -55,7 +60,6 @@ html_title = project
|
||||
html_theme = 'sphinx_book_theme'
|
||||
html_logo = 'assets/logos/vllm-logo-text-light.png'
|
||||
html_theme_options = {
|
||||
'logo_only': True,
|
||||
'path_to_docs': 'docs/source',
|
||||
'repository_url': 'https://github.com/vllm-project/vllm',
|
||||
'use_repository_button': True,
|
||||
@ -64,4 +68,31 @@ html_theme_options = {
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
# html_static_path = ['_static']
|
||||
|
||||
# Mock out external dependencies here.
|
||||
autodoc_mock_imports = [
|
||||
"torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
|
||||
"vllm.cuda_utils", "vllm._C"
|
||||
]
|
||||
|
||||
for mock_target in autodoc_mock_imports:
|
||||
if mock_target in sys.modules:
|
||||
logger.info(
|
||||
f"Potentially problematic mock target ({mock_target}) found; "
|
||||
"autodoc_mock_imports cannot mock modules that have already "
|
||||
"been loaded into sys.modules when the sphinx build starts.")
|
||||
|
||||
|
||||
class MockedClassDocumenter(autodoc.ClassDocumenter):
|
||||
"""Remove note about base class when a class is derived from object."""
|
||||
|
||||
def add_line(self, line: str, source: str, *lineno: int) -> None:
|
||||
if line == " Bases: :py:class:`object`":
|
||||
return
|
||||
super().add_line(line, source, *lineno)
|
||||
|
||||
|
||||
autodoc.ClassDocumenter = MockedClassDocumenter
|
||||
|
||||
navigation_with_keys = False
|
||||
|
7
docs/source/dev/engine/async_llm_engine.rst
Normal file
@ -0,0 +1,7 @@
|
||||
|
||||
AsyncLLMEngine
|
||||
=================================
|
||||
|
||||
.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
|
||||
:members: generate, abort
|
||||
:show-inheritance:
|
13
docs/source/dev/engine/engine_index.rst
Normal file
@ -0,0 +1,13 @@
|
||||
vLLM Engine
|
||||
=================================
|
||||
|
||||
.. automodule:: vllm.engine
|
||||
.. currentmodule:: vllm.engine
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Engines
|
||||
|
||||
llm_engine
|
||||
async_llm_engine
|
||||
|
6
docs/source/dev/engine/llm_engine.rst
Normal file
@ -0,0 +1,6 @@
|
||||
LLMEngine
|
||||
=================================
|
||||
|
||||
.. autoclass:: vllm.engine.llm_engine.LLMEngine
|
||||
:members: add_request, abort_request, step, _init_cache
|
||||
:show-inheritance:
|
172
docs/source/getting_started/amd-installation.rst
Normal file
@ -0,0 +1,172 @@
|
||||
.. _installation_rocm:
|
||||
|
||||
Installation with ROCm
|
||||
======================
|
||||
|
||||
vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
|
||||
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
|
||||
Data types currently supported in ROCm are FP16 and BF16.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 -- 3.11
|
||||
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
|
||||
* Pytorch 2.0.1/2.1.1/2.2
|
||||
* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
|
||||
|
||||
Installation options:
|
||||
|
||||
#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
|
||||
#. :ref:`Build from source <build_from_source_rocm>`
|
||||
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`
|
||||
|
||||
.. _quick_start_docker_rocm:
|
||||
|
||||
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
This option is for ROCm 5.7 only:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
embeddedllminfo/vllm-rocm \
|
||||
bash
|
||||
|
||||
|
||||
.. _build_from_source_rocm:
|
||||
|
||||
Option 2: Build from source
|
||||
---------------------------
|
||||
|
||||
You can build and install vLLM from source:
|
||||
|
||||
Below instruction is for ROCm 5.7 only.
|
||||
At the time of this documentation update, PyTorch on ROCm 6.0 wheel is not yet available on the PyTorch website.
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `Pytorch <https://pytorch.org/>`_
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version
|
||||
|
||||
|
||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||
|
||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||
|
||||
.. note::
|
||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
|
||||
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
|
||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install xformers==0.0.23 --no-deps
|
||||
$ bash patch_xformers.rocm.sh
|
||||
|
||||
3. Build vLLM.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ cd vllm
|
||||
$ pip install -U -r requirements-rocm.txt
|
||||
$ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
|
||||
|
||||
|
||||
.. _build_from_source_docker_rocm:
|
||||
|
||||
Option 3: Build from source with docker
|
||||
-----------------------------------------------------
|
||||
|
||||
You can build and install vLLM from source:
|
||||
|
||||
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
|
||||
|
||||
The `Dokerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of docker image using the following arguments:
|
||||
|
||||
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
|
||||
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
|
||||
* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
|
||||
* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 before flash-attention supports this target.
|
||||
|
||||
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
|
||||
|
||||
For example, to build docker image for vllm on ROCm 5.7, you can run:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
|
||||
-f Dockerfile.rocm -t vllm-rocm .
|
||||
|
||||
To build vllm on ROCm 6.0, you can use the default:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
vllm-rocm \
|
||||
bash
|
||||
|
||||
Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `Pytorch <https://pytorch.org/>`_
|
||||
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
|
||||
|
||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||
|
||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||
|
||||
.. note::
|
||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
|
||||
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
|
||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install xformers==0.0.23 --no-deps
|
||||
$ bash patch_xformers.rocm.sh
|
||||
|
||||
3. Build vLLM.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ cd vllm
|
||||
$ pip install -U -r requirements-rocm.txt
|
||||
$ python setup.py install # This may take 5-10 minutes.
|
||||
|
||||
.. note::
|
||||
|
||||
- You may need to turn on the ``--enforce-eager`` flag if you experience process hang when running the `benchmark_thoughput.py` script to test your installation.
|
||||
|
@ -3,30 +3,14 @@
|
||||
Installation
|
||||
============
|
||||
|
||||
vLLM is a Python library that also contains some C++ and CUDA code.
|
||||
This additional code requires compilation on the user's machine.
|
||||
vLLM is a Python library that also contains pre-compiled C++ and CUDA (12.1) binaries.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 or higher
|
||||
* CUDA: 11.0 -- 11.8
|
||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
|
||||
|
||||
.. note::
|
||||
As of now, vLLM does not support CUDA 12.
|
||||
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
|
||||
|
||||
.. tip::
|
||||
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Pull the Docker image with CUDA 11.8.
|
||||
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
|
||||
|
||||
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
|
||||
* Python: 3.8 -- 3.11
|
||||
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.)
|
||||
|
||||
Install with pip
|
||||
----------------
|
||||
@ -36,11 +20,31 @@ You can install vLLM using pip:
|
||||
.. code-block:: console
|
||||
|
||||
$ # (Optional) Create a new conda environment.
|
||||
$ conda create -n myenv python=3.8 -y
|
||||
$ conda create -n myenv python=3.9 -y
|
||||
$ conda activate myenv
|
||||
|
||||
$ # Install vLLM.
|
||||
$ pip install vllm # This may take 5-10 minutes.
|
||||
$ # Install vLLM with CUDA 12.1.
|
||||
$ pip install vllm
|
||||
|
||||
.. note::
|
||||
|
||||
As of now, vLLM's binaries are compiled on CUDA 12.1 by default.
|
||||
However, you can install vLLM with CUDA 11.8 by running:
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Install vLLM with CUDA 11.8.
|
||||
$ export VLLM_VERSION=0.2.4
|
||||
$ export PYTHON_VERSION=39
|
||||
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
||||
|
||||
$ # Re-install PyTorch with CUDA 11.8.
|
||||
$ pip uninstall torch -y
|
||||
$ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
$ # Re-install xFormers with CUDA 11.8.
|
||||
$ pip uninstall xformers -y
|
||||
$ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
|
||||
.. _build_from_source:
|
||||
@ -55,3 +59,21 @@ You can also build and install vLLM from source:
|
||||
$ git clone https://github.com/vllm-project/vllm.git
|
||||
$ cd vllm
|
||||
$ pip install -e . # This may take 5-10 minutes.
|
||||
|
||||
.. tip::
|
||||
If you have trouble building vLLM, we recommend using the NVIDIA PyTorch Docker image.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ # Use `--ipc=host` to make sure the shared memory is large enough.
|
||||
$ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
|
||||
|
||||
.. note::
|
||||
If you are developing the C++ backend of vLLM, consider building vLLM with
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python setup.py develop
|
||||
|
||||
since it will give you incremental builds. The downside is that this method
|
||||
is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
|
||||
|