mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
443 Commits
Author | SHA1 | Date | |
---|---|---|---|
0405645a6c | |||
41bf5612f5 | |||
a2769032ca | |||
f17f1d4608 | |||
1c1bb0bbf2 | |||
e0cc5f259a | |||
73aa6cfdf7 | |||
27b78c73ca | |||
b02fd288b2 | |||
ff7424f491 | |||
d93bf4da85 | |||
036ca94c25 | |||
ef001d98ef | |||
5f671cb4c3 | |||
bd02164cf9 | |||
46fb056749 | |||
dd6a3a02cb | |||
a7e3eba66f | |||
fbb5bd4cef | |||
80fcc3ed1c | |||
c386c43ca3 | |||
f26d790718 | |||
0f657bdc52 | |||
3fd1fb63ef | |||
925d2f1908 | |||
8f58a51358 | |||
2079e43bee | |||
e29d4358ef | |||
8cbc424975 | |||
dd66fd2b01 | |||
0f465ab533 | |||
23a7cbc88b | |||
426a5c3625 | |||
ddee88d0ff | |||
823ab79633 | |||
6116ca8cd7 | |||
2bc3fbba0c | |||
3f1fc7425a | |||
01ba927040 | |||
103bd17ac5 | |||
ce69f7f754 | |||
624a1e4711 | |||
372bf0890b | |||
5204ff5c3f | |||
0cc6b383d7 | |||
28e0750847 | |||
582cf78798 | |||
0034b09ceb | |||
72bac73067 | |||
68f11149d8 | |||
72f4880425 | |||
aa2cd2c43d | |||
9ddc35220b | |||
a5255270c3 | |||
0ee349b553 | |||
fa63e710c7 | |||
2a0309a646 | |||
324960a95c | |||
f1fc0510df | |||
bf21481dde | |||
fb30ee92ee | |||
221d388cc5 | |||
3132a933b6 | |||
df5dafaa5b | |||
ab5bbf5ae3 | |||
3bb8e2c9a2 | |||
e784c6b998 | |||
9a0f3bdbe5 | |||
c7c9851036 | |||
3c818bdb42 | |||
6dd94dbe94 | |||
0e74d797ce | |||
55ef66edf4 | |||
5e5630a478 | |||
d3d6bb13fb | |||
24b0205f58 | |||
c5cffcd0cd | |||
682b55bc07 | |||
9726ad676d | |||
eb5cb5e528 | |||
2cbeedad09 | |||
2c85529bfc | |||
e97f802b2d | |||
6e650f56a1 | |||
3f50c148fd | |||
8c01b8022c | |||
99d01a5e3d | |||
d07efb31c5 | |||
978b45f399 | |||
c5b4b11d7f | |||
8ae5ff2009 | |||
511627445e | |||
f0ef37233e | |||
7551a34032 | |||
01a55941f5 | |||
8d7aa9de71 | |||
68c4421b6d | |||
aea94362c9 | |||
7206ce4ce1 | |||
96f6a7596f | |||
84bee4bd5c | |||
fc66dee76d | |||
6609cdf019 | |||
16366ee8bb | |||
528dbcac7d | |||
cd7b6f0857 | |||
68ad4e3a8d | |||
4004f144f3 | |||
66818e5b63 | |||
222a9dc350 | |||
cbdc4ad5a5 | |||
016e3676e7 | |||
64ea24d0b3 | |||
df76e5af26 | |||
09ccc9c8f7 | |||
69196a9bc7 | |||
2acba47d9b | |||
9c485d9e25 | |||
fa9ee08121 | |||
347eeebe3b | |||
18fd4a8331 | |||
132a132100 | |||
1e60f87bb3 | |||
9705b90bcf | |||
3aec49e56f | |||
c64612802b | |||
9a7c3a0042 | |||
b197a5ccfd | |||
c81081fece | |||
a94eee4456 | |||
f2e9f2a3be | |||
1f1542afa9 | |||
96912550c8 | |||
2fc6944c5e | |||
5fe6bf29d6 | |||
d4b62d4641 | |||
ecf67814f1 | |||
750f4cabfa | |||
06a760d6e8 | |||
da7512215f | |||
af69a6aded | |||
7bd3630067 | |||
96663699b2 | |||
18572e3384 | |||
86bfb6dba7 | |||
5f0ec3935a | |||
c222f47992 | |||
170eb35079 | |||
b37d82791e | |||
3127e975fb | |||
4001ea1266 | |||
5c89a29c22 | |||
59a0192fb9 | |||
83609791d2 | |||
0974c9bc5c | |||
d2643128f7 | |||
c5c06209ec | |||
3ea7b94523 | |||
51ef828f10 | |||
df450aa567 | |||
bbe5f9de7d | |||
81763c58a0 | |||
edaae198e7 | |||
936db119ed | |||
e66faf4809 | |||
630eb5b5ce | |||
4e94951bb1 | |||
7a8a48d51e | |||
32eb0da808 | |||
6d0e3d3724 | |||
02798ecabe | |||
813f249f02 | |||
da02cb4b27 | |||
c09503ddd6 | |||
2b83503227 | |||
7b98a65ae6 | |||
b5b57e301e | |||
54cacf008f | |||
58fd57ff1d | |||
87a0c076af | |||
d4e6194570 | |||
07934cc237 | |||
69d765f5a5 | |||
8027a72461 | |||
d75ab55f10 | |||
d1adb9b403 | |||
b8bfa46a18 | |||
1475847a14 | |||
fead53ba78 | |||
ebc73f2828 | |||
d06e824006 | |||
62b06ba23d | |||
5fd24ec02e | |||
874f7c292a | |||
92e793d91a | |||
bf53e0c70b | |||
dd7c9ad870 | |||
9aa1519f08 | |||
f8ef146f03 | |||
fa0050db08 | |||
cd9d06fb8d | |||
ebd8c669ef | |||
70755e819e | |||
edce722eaa | |||
57e729e874 | |||
de0526f668 | |||
5ecf3e0aaf | |||
97eb97b5a4 | |||
3adf0ffda8 | |||
ad388d25a8 | |||
cbe94391eb | |||
994fc655b7 | |||
3f9b7ab9f5 | |||
ad34c0df0f | |||
f218f9c24d | |||
0794e7446e | |||
b7ee940a82 | |||
9ddac56311 | |||
1a51b9f872 | |||
42f5e7c52a | |||
a3a3ee4e6f | |||
87054a57ab | |||
c9d6ff530b | |||
a2d2acb4c8 | |||
2e0e017610 | |||
1f18adb245 | |||
bb354e6b2d | |||
ff39141a49 | |||
8a1f938e6f | |||
078da31903 | |||
1a401252b5 | |||
f35ec461fc | |||
289b5191d5 | |||
c6db21313c | |||
a7d59688fb | |||
458e63a2c6 | |||
e8c23ff989 | |||
cd8249903f | |||
0f8cafe2d1 | |||
5340a30d01 | |||
89ce62a316 | |||
c3f05b09a0 | |||
cf6bbcb493 | |||
80ea3af1a0 | |||
9dd02d85ca | |||
f7b3ba82c3 | |||
619ae268c3 | |||
d14e98d924 | |||
9597a095f2 | |||
263a870ee1 | |||
8bddb73512 | |||
f967e51f38 | |||
43f3d9e699 | |||
b25cfab9a0 | |||
4b657d3292 | |||
d697dc01b4 | |||
a991f7d508 | |||
7a3a83e3b8 | |||
c32a7c7c0c | |||
2118d0565c | |||
899136b857 | |||
c9f09a4fe8 | |||
d45cbe70f5 | |||
8a579408f3 | |||
46fa98ccad | |||
aa1e77a19c | |||
5959564f94 | |||
f33e033e27 | |||
482cdc494e | |||
20410b2fda | |||
12664ddda5 | |||
241ad7b301 | |||
d85c47d6ad | |||
ef725feafc | |||
d907be7dc7 | |||
d53575a5f0 | |||
61af633256 | |||
ac2f3f7fee | |||
cf5f000d21 | |||
3de2b1eafb | |||
b844b99ad3 | |||
c3cf54dda4 | |||
36f5303578 | |||
9a228348d2 | |||
bd82872211 | |||
405eb8e396 | |||
65097ca0af | |||
1d967acb45 | |||
0bd1ff4346 | |||
310aca88c9 | |||
a732900efc | |||
d848800e88 | |||
730e9592e9 | |||
1fe554bac3 | |||
615e4a5401 | |||
3db0cafdf1 | |||
526de822d5 | |||
56fe4c297c | |||
47de8821d3 | |||
5984499e47 | |||
ca47e176af | |||
78f4590b60 | |||
2f7024987e | |||
6cd40a5bfe | |||
aba8d6ee00 | |||
2a0596bc48 | |||
f12141170a | |||
cfd3219f58 | |||
a1b2b8606e | |||
ad9f1aa679 | |||
889e662eae | |||
ef68eb28d8 | |||
259abd8953 | |||
f645eb6954 | |||
f4923cb8bc | |||
b640b19cc0 | |||
dc71af0a71 | |||
4d29e91be8 | |||
91445c7bc8 | |||
5950f555a1 | |||
a4e2b26856 | |||
973f5dc581 | |||
c994223d56 | |||
869579a702 | |||
c0efe92d8b | |||
d9fa1c05ad | |||
2de197bdd4 | |||
869e829b85 | |||
8f37be38eb | |||
8082ad7950 | |||
1e4ce295ae | |||
ce1917fcf2 | |||
e512f76a89 | |||
898cdf033e | |||
0f3f3c86ec | |||
b278557935 | |||
8ceffbf315 | |||
d93d2d74fd | |||
d0169e1b0f | |||
08fb75c72e | |||
91b361ae89 | |||
e20c92bb61 | |||
32c9eff2ff | |||
4ca5d40adc | |||
9279b9f83d | |||
ee77fdb5de | |||
996357e480 | |||
2a622d704a | |||
9c749713f6 | |||
022c5c6944 | |||
f8fcca100b | |||
06bfb51963 | |||
408e560015 | |||
402d378360 | |||
9e764e7b10 | |||
33fc1e2e86 | |||
eba17173d3 | |||
635b897246 | |||
4068f4b5b5 | |||
47831430cc | |||
65c08928c2 | |||
ba214dffbe | |||
eed11ebee9 | |||
300acb8347 | |||
d91457d529 | |||
fbf2564554 | |||
d1d49397e7 | |||
9c93636d84 | |||
e5d7ed0c53 | |||
ad0d567e1c | |||
bf0d97d786 | |||
a655eb3025 | |||
1543914c04 | |||
61fed92c7e | |||
80c751e7f6 | |||
e1a5c2f0a1 | |||
fd3a62a122 | |||
07064cb1d4 | |||
2f1e8e8f54 | |||
68d37809b9 | |||
5dba257506 | |||
187e32997c | |||
b55ed6ef8a | |||
2f385183f3 | |||
84c35c374a | |||
8c38ee7007 | |||
b6087a6bee | |||
23c1b10a4c | |||
a115ac46b5 | |||
73001445fb | |||
6d70198b17 | |||
f962f426bc | |||
11d8a091c6 | |||
365801fedd | |||
4db72e57f6 | |||
0c6f998554 | |||
e7c7c5e822 | |||
8c3230d8c1 | |||
2c5718809b | |||
82c49d3260 | |||
74fa1d123c | |||
a2a40bcd0d | |||
ccb1aabcca | |||
36e7670045 | |||
5886aa496e | |||
8d9b6721e7 | |||
b12e87f942 | |||
5dbf854553 | |||
970d6d0776 | |||
628ec6c17b | |||
3682e33f9f | |||
0aa38d16f5 | |||
faef77c0d6 | |||
dba4d9dec6 | |||
32b4c63f02 | |||
4fb8e329fd | |||
328841d002 | |||
d427e5cfda | |||
42bb201fd6 | |||
59d6bb4c86 | |||
b7dcc003dc | |||
d34be24bb1 | |||
b5cbe8eeb3 | |||
df04dffade | |||
a60731247f | |||
ac79799403 | |||
dde1fa18c9 | |||
0240402c46 | |||
55509c2114 | |||
101418096f | |||
5ce4627a7e | |||
7af553ea30 | |||
2c9b8ea2b0 | |||
d003f3ea39 | |||
6c6f7fe8a8 | |||
2339d59f92 | |||
1b875a0ef3 | |||
eb881ed006 | |||
46d4359450 | |||
81b979f2a8 | |||
371d04d39b | |||
0c0c2015c5 | |||
82d24f7aac |
@ -2,8 +2,11 @@ import os
|
||||
import sys
|
||||
import zipfile
|
||||
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
|
||||
# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 300 MiB
|
||||
# Note that we have 400 MiB quota, please use it wisely.
|
||||
# See https://github.com/pypi/support/issues/3792 .
|
||||
# Please also sync the value with the one in Dockerfile.
|
||||
VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300))
|
||||
|
||||
|
||||
def print_top_10_largest_files(zip_file):
|
||||
|
@ -1,5 +1,6 @@
|
||||
steps:
|
||||
- label: "Wait for container to be ready"
|
||||
key: wait-for-container-image
|
||||
agents:
|
||||
queue: A100
|
||||
plugins:
|
||||
@ -10,12 +11,11 @@ steps:
|
||||
command:
|
||||
- sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh
|
||||
|
||||
- wait
|
||||
|
||||
- label: "A100"
|
||||
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
|
||||
agents:
|
||||
queue: A100
|
||||
depends_on: wait-for-container-image
|
||||
plugins:
|
||||
- kubernetes:
|
||||
podSpec:
|
||||
@ -49,6 +49,7 @@ steps:
|
||||
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
|
||||
agents:
|
||||
queue: H200
|
||||
depends_on: wait-for-container-image
|
||||
plugins:
|
||||
- docker#v5.12.0:
|
||||
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
|
||||
@ -73,7 +74,7 @@ steps:
|
||||
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
|
||||
agents:
|
||||
queue: H100
|
||||
depends_on: block-h100
|
||||
depends_on: wait-for-container-image
|
||||
plugins:
|
||||
- docker#v5.12.0:
|
||||
image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT
|
||||
|
@ -43,7 +43,7 @@ main() {
|
||||
|
||||
|
||||
|
||||
# The figures should be genereated by a separate process outside the CI/CD pipeline
|
||||
# The figures should be generated by a separate process outside the CI/CD pipeline
|
||||
|
||||
# # generate figures
|
||||
# python3 -m pip install tabulate pandas matplotlib
|
||||
|
@ -301,6 +301,104 @@ run_serving_tests() {
|
||||
kill_gpu_processes
|
||||
}
|
||||
|
||||
run_genai_perf_tests() {
|
||||
# run genai-perf tests
|
||||
|
||||
# $1: a json file specifying genai-perf test cases
|
||||
local genai_perf_test_file
|
||||
genai_perf_test_file=$1
|
||||
|
||||
# Iterate over genai-perf tests
|
||||
jq -c '.[]' "$genai_perf_test_file" | while read -r params; do
|
||||
# get the test name, and append the GPU type back to it.
|
||||
test_name=$(echo "$params" | jq -r '.test_name')
|
||||
|
||||
# if TEST_SELECTOR is set, only run the test cases that match the selector
|
||||
if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then
|
||||
echo "Skip test case $test_name."
|
||||
continue
|
||||
fi
|
||||
|
||||
# prepend the current serving engine to the test name
|
||||
test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name}
|
||||
|
||||
# get common parameters
|
||||
common_params=$(echo "$params" | jq -r '.common_parameters')
|
||||
model=$(echo "$common_params" | jq -r '.model')
|
||||
tp=$(echo "$common_params" | jq -r '.tp')
|
||||
dataset_name=$(echo "$common_params" | jq -r '.dataset_name')
|
||||
dataset_path=$(echo "$common_params" | jq -r '.dataset_path')
|
||||
port=$(echo "$common_params" | jq -r '.port')
|
||||
num_prompts=$(echo "$common_params" | jq -r '.num_prompts')
|
||||
reuse_server=$(echo "$common_params" | jq -r '.reuse_server')
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters")
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ $reuse_server == "true" ]]; then
|
||||
echo "Reuse previous server for test case $test_name"
|
||||
else
|
||||
kill_gpu_processes
|
||||
bash "$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh" \
|
||||
"$server_params" "$common_params"
|
||||
fi
|
||||
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "$CURRENT_LLM_SERVING_ENGINE server is up and running."
|
||||
else
|
||||
echo ""
|
||||
echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period."
|
||||
break
|
||||
fi
|
||||
|
||||
# iterate over different QPS
|
||||
for qps in $qps_list; do
|
||||
# remove the surrounding single quote from qps
|
||||
if [[ "$qps" == *"inf"* ]]; then
|
||||
echo "qps was $qps"
|
||||
qps=$num_prompts
|
||||
echo "now qps is $qps"
|
||||
fi
|
||||
|
||||
new_test_name=$test_name"_qps_"$qps
|
||||
backend=$CURRENT_LLM_SERVING_ENGINE
|
||||
|
||||
if [[ "$backend" == *"vllm"* ]]; then
|
||||
backend="vllm"
|
||||
fi
|
||||
#TODO: add output dir.
|
||||
client_command="genai-perf profile \
|
||||
-m $model \
|
||||
--service-kind openai \
|
||||
--backend vllm \
|
||||
--endpoint-type chat \
|
||||
--streaming \
|
||||
--url localhost:$port \
|
||||
--request-rate $qps \
|
||||
--num-prompts $num_prompts \
|
||||
"
|
||||
|
||||
echo "Client command: $client_command"
|
||||
|
||||
eval "$client_command"
|
||||
|
||||
#TODO: process/record outputs
|
||||
done
|
||||
done
|
||||
|
||||
kill_gpu_processes
|
||||
|
||||
}
|
||||
|
||||
prepare_dataset() {
|
||||
|
||||
@ -328,12 +426,17 @@ main() {
|
||||
|
||||
pip install -U transformers
|
||||
|
||||
pip install -r requirements-dev.txt
|
||||
which genai-perf
|
||||
|
||||
# check storage
|
||||
df -h
|
||||
|
||||
ensure_installed wget
|
||||
ensure_installed curl
|
||||
ensure_installed jq
|
||||
# genai-perf dependency
|
||||
ensure_installed libb64-0d
|
||||
|
||||
prepare_dataset
|
||||
|
||||
@ -345,6 +448,10 @@ main() {
|
||||
# run the test
|
||||
run_serving_tests "$BENCHMARK_ROOT/tests/nightly-tests.json"
|
||||
|
||||
# run genai-perf tests
|
||||
run_genai_perf_tests "$BENCHMARK_ROOT/tests/genai-perf-tests.json"
|
||||
mv artifacts/ $RESULTS_FOLDER/
|
||||
|
||||
# upload benchmark results to buildkite
|
||||
python3 -m pip install tabulate pandas
|
||||
python3 "$BENCHMARK_ROOT/scripts/summary-nightly-results.py"
|
||||
|
23
.buildkite/nightly-benchmarks/tests/genai-perf-tests.json
Normal file
23
.buildkite/nightly-benchmarks/tests/genai-perf-tests.json
Normal file
@ -0,0 +1,23 @@
|
||||
[
|
||||
{
|
||||
"test_name": "llama8B_tp1_genai_perf",
|
||||
"qps_list": [4,8,16,32],
|
||||
"common_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"tp": 1,
|
||||
"port": 8000,
|
||||
"num_prompts": 500,
|
||||
"reuse_server": false
|
||||
},
|
||||
"vllm_server_parameters": {
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"gpu_memory_utilization": 0.9,
|
||||
"num_scheduler_steps": 10,
|
||||
"max_num_seqs": 512,
|
||||
"dtype": "bfloat16"
|
||||
},
|
||||
"genai_perf_input_parameters": {
|
||||
}
|
||||
}
|
||||
]
|
@ -9,36 +9,33 @@ CORE_RANGE=${CORE_RANGE:-48-95}
|
||||
NUMA_NODE=${NUMA_NODE:-1}
|
||||
|
||||
# Try building the docker image
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test -f Dockerfile.cpu .
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build -t cpu-test-"$BUILDKITE_BUILD_NUMBER" -f Dockerfile.cpu .
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2 -f Dockerfile.cpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test-"$NUMA_NODE" cpu-test-avx2-"$NUMA_NODE" || true; }
|
||||
remove_docker_container() { set -e; docker rm -f cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" || true; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2-"$NUMA_NODE" cpu-test-avx2
|
||||
--cpuset-mems="$NUMA_NODE" --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2
|
||||
|
||||
function cpu_tests() {
|
||||
set -e
|
||||
export NUMA_NODE=$2
|
||||
|
||||
# offline inference
|
||||
docker exec cpu-test-avx2-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-avx2-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
python3 examples/offline_inference.py"
|
||||
python3 examples/offline_inference/basic.py"
|
||||
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pip install pytest pytest-asyncio \
|
||||
decord einops librosa peft Pillow sentence-transformers soundfile \
|
||||
transformers_stream_generator matplotlib datamodel_code_generator
|
||||
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -r vllm/requirements-test.txt
|
||||
pytest -v -s tests/models/decoder_only/language -m cpu_model
|
||||
pytest -v -s tests/models/embedding/language -m cpu_model
|
||||
pytest -v -s tests/models/encoder_decoder/language -m cpu_model
|
||||
@ -46,26 +43,26 @@ function cpu_tests() {
|
||||
pytest -v -s tests/models/decoder_only/vision_language -m cpu_model"
|
||||
|
||||
# Run compressed-tensor test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
||||
|
||||
# Run AWQ test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# Run chunked-prefill and prefix-cache test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v -k cpu_model \
|
||||
tests/basic_correctness/test_chunked_prefill.py"
|
||||
|
||||
# online inference
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# online serving
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
export VLLM_CPU_KVCACHE_SPACE=10
|
||||
export VLLM_CPU_OMP_THREADS_BIND=$1
|
||||
@ -78,8 +75,14 @@ function cpu_tests() {
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer facebook/opt-125m"
|
||||
|
||||
# Run multi-lora tests
|
||||
docker exec cpu-test-"$BUILDKITE_BUILD_NUMBER"-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/lora/test_qwen2vl.py"
|
||||
}
|
||||
|
||||
# All of CPU tests are expected to be finished less than 25 mins.
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
@ -24,5 +24,5 @@ remove_docker_container
|
||||
|
||||
# Run the image and test offline inference
|
||||
docker run --name gh200-test --gpus=all --entrypoint="" gh200-test bash -c '
|
||||
python3 examples/offline_inference.py
|
||||
python3 examples/offline_inference/basic.py
|
||||
'
|
||||
|
@ -8,9 +8,17 @@ set -ex
|
||||
docker build -t hpu-test-env -f Dockerfile.hpu .
|
||||
|
||||
# Setup cleanup
|
||||
# certain versions of HPU software stack have a bug that can
|
||||
# override the exit code of the script, so we need to use
|
||||
# separate remove_docker_container and remove_docker_container_and_exit
|
||||
# functions, while other platforms only need one remove_docker_container
|
||||
# function.
|
||||
EXITCODE=1
|
||||
remove_docker_container() { docker rm -f hpu-test || true; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; }
|
||||
trap remove_docker_container_and_exit EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference.py
|
||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py
|
||||
EXITCODE=$?
|
||||
|
@ -3,6 +3,18 @@
|
||||
# This script build the Neuron docker image and run the API server inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -e
|
||||
set -v
|
||||
|
||||
image_name="neuron/vllm-ci"
|
||||
container_name="neuron_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
|
||||
|
||||
HF_CACHE="$(realpath ~)/huggingface"
|
||||
mkdir -p "${HF_CACHE}"
|
||||
HF_MOUNT="/root/.cache/huggingface"
|
||||
|
||||
NEURON_COMPILE_CACHE_URL="$(realpath ~)/neuron_compile_cache"
|
||||
mkdir -p "${NEURON_COMPILE_CACHE_URL}"
|
||||
NEURON_COMPILE_CACHE_MOUNT="/root/.cache/neuron_compile_cache"
|
||||
|
||||
# Try building the docker image
|
||||
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
|
||||
@ -13,41 +25,33 @@ if [ -f /tmp/neuron-docker-build-timestamp ]; then
|
||||
last_build=$(cat /tmp/neuron-docker-build-timestamp)
|
||||
current_time=$(date +%s)
|
||||
if [ $((current_time - last_build)) -gt 86400 ]; then
|
||||
docker system prune -f
|
||||
# Remove dangling images (those that are not tagged and not used by any container)
|
||||
docker image prune -f
|
||||
# Remove unused volumes / force the system prune for old images as well.
|
||||
docker volume prune -f && docker system prune -f
|
||||
# Remove huggingface model artifacts and compiler cache
|
||||
rm -rf "${HF_MOUNT:?}/*"
|
||||
rm -rf "${NEURON_COMPILE_CACHE_MOUNT:?}/*"
|
||||
echo "$current_time" > /tmp/neuron-docker-build-timestamp
|
||||
fi
|
||||
else
|
||||
date "+%s" > /tmp/neuron-docker-build-timestamp
|
||||
fi
|
||||
|
||||
docker build -t neuron -f Dockerfile.neuron .
|
||||
docker build -t "${image_name}" -f Dockerfile.neuron .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f neuron || true; }
|
||||
remove_docker_container() {
|
||||
docker image rm -f "${image_name}" || true;
|
||||
}
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image
|
||||
docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \
|
||||
--model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 &
|
||||
|
||||
# Wait for the server to start
|
||||
wait_for_server_to_start() {
|
||||
timeout=300
|
||||
counter=0
|
||||
|
||||
while [ "$(curl -s -o /dev/null -w '%{http_code}' localhost:8000/health)" != "200" ]; do
|
||||
sleep 1
|
||||
counter=$((counter + 1))
|
||||
if [ $counter -ge $timeout ]; then
|
||||
echo "Timeout after $timeout seconds"
|
||||
break
|
||||
fi
|
||||
done
|
||||
}
|
||||
wait_for_server_to_start
|
||||
|
||||
# Test a simple prompt
|
||||
curl -X POST -H "Content-Type: application/json" \
|
||||
localhost:8000/generate \
|
||||
-d '{"prompt": "San Francisco is a"}'
|
||||
docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \
|
||||
-v "${HF_CACHE}:${HF_MOUNT}" \
|
||||
-e "HF_HOME=${HF_MOUNT}" \
|
||||
-v "${NEURON_COMPILE_CACHE_URL}:${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||
-e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \
|
||||
--name "${container_name}" \
|
||||
${image_name} \
|
||||
/bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys"
|
||||
|
@ -13,4 +13,4 @@ trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference.py
|
||||
docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference/basic.py
|
||||
|
11
.buildkite/run-tpu-test.sh
Normal file → Executable file
11
.buildkite/run-tpu-test.sh
Normal file → Executable file
@ -14,4 +14,13 @@ remove_docker_container
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
# Run a simple end-to-end example.
|
||||
docker run --privileged --net host --shm-size=16G -it -e "HF_TOKEN=$HF_TOKEN" --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && python3 -m pip install lm_eval[api]==0.4.4 && pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
|
||||
docker run --privileged --net host --shm-size=16G -it \
|
||||
-e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
|
||||
vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install pytest \
|
||||
&& python3 -m pip install lm_eval[api]==0.4.4 \
|
||||
&& pytest -v -s /workspace/vllm/tests/entrypoints/openai/test_accuracy.py \
|
||||
&& pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
|
||||
&& python3 /workspace/vllm/tests/tpu/test_compilation.py \
|
||||
&& python3 /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"
|
||||
|
@ -14,6 +14,6 @@ remove_docker_container
|
||||
|
||||
# Run the image and test offline inference/tensor parallel
|
||||
docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c '
|
||||
python3 examples/offline_inference.py
|
||||
python3 examples/offline_inference_cli.py -tp 2
|
||||
python3 examples/offline_inference/basic.py
|
||||
python3 examples/offline_inference/cli.py -tp 2
|
||||
'
|
||||
|
@ -38,7 +38,7 @@ steps:
|
||||
- pip install -r requirements-docs.txt
|
||||
- SPHINXOPTS=\"-W\" make html
|
||||
# Check API reference (if it fails, you may have missing mock imports)
|
||||
- grep \"sig sig-object py\" build/html/dev/sampling_params.html
|
||||
- grep \"sig sig-object py\" build/html/api/inference_params.html
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||
fast_check: true
|
||||
@ -76,7 +76,9 @@ steps:
|
||||
- tests/basic_correctness/test_basic_correctness
|
||||
- tests/basic_correctness/test_cpu_offload
|
||||
- tests/basic_correctness/test_preemption
|
||||
- tests/basic_correctness/test_cumem.py
|
||||
commands:
|
||||
- pytest -v -s basic_correctness/test_cumem.py
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
@ -106,14 +108,12 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
@ -127,11 +127,15 @@ steps:
|
||||
- tests/distributed
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile
|
||||
- examples/offline_inference/rlhf.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- python3 ../examples/offline_inference/rlhf.py
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
num_gpus: 2
|
||||
@ -179,7 +183,16 @@ steps:
|
||||
- vllm/
|
||||
- tests/v1
|
||||
commands:
|
||||
- VLLM_USE_V1=1 pytest -v -s v1
|
||||
# split the test to avoid interference
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/core
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/engine
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/sample
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/worker
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
|
||||
# TODO: accuracy does not match, whether setting
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- VLLM_USE_V1=1 pytest -v -s v1/e2e
|
||||
|
||||
- label: Examples Test # 25min
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
@ -189,19 +202,19 @@ steps:
|
||||
- examples/
|
||||
commands:
|
||||
- pip install tensorizer # for tensorizer test
|
||||
- python3 offline_inference.py
|
||||
- python3 cpu_offload.py
|
||||
- python3 offline_inference_chat.py
|
||||
- python3 offline_inference_with_prefix.py
|
||||
- python3 llm_engine_example.py
|
||||
- python3 offline_inference_vision_language.py
|
||||
- python3 offline_inference_vision_language_multi_image.py
|
||||
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference_encoder_decoder.py
|
||||
- python3 offline_inference_classification.py
|
||||
- python3 offline_inference_embedding.py
|
||||
- python3 offline_inference_scoring.py
|
||||
- python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
- python3 offline_inference/basic.py
|
||||
- python3 offline_inference/cpu_offload.py
|
||||
- python3 offline_inference/chat.py
|
||||
- python3 offline_inference/prefix_caching.py
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/vision_language.py
|
||||
- python3 offline_inference/vision_language_multi_image.py
|
||||
- python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
- python3 offline_inference/classification.py
|
||||
- python3 offline_inference/embedding.py
|
||||
- python3 offline_inference/scoring.py
|
||||
- python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
|
||||
|
||||
- label: Prefix Caching Test # 9min
|
||||
mirror_hardwares: [amd]
|
||||
@ -216,6 +229,7 @@ steps:
|
||||
- vllm/model_executor/layers
|
||||
- vllm/sampling_metadata.py
|
||||
- tests/samplers
|
||||
- tests/conftest.py
|
||||
commands:
|
||||
- pytest -v -s samplers
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
@ -231,20 +245,22 @@ steps:
|
||||
- pytest -v -s test_logits_processor.py
|
||||
- pytest -v -s model_executor/test_guided_processors.py
|
||||
|
||||
- label: Speculative decoding tests # 30min
|
||||
- label: Speculative decoding tests # 40min
|
||||
source_file_dependencies:
|
||||
- vllm/spec_decode
|
||||
- tests/spec_decode
|
||||
- vllm/model_executor/models/eagle.py
|
||||
commands:
|
||||
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
|
||||
parallelism: 4
|
||||
|
||||
- label: "PyTorch Fullgraph Smoke Test" # 9min
|
||||
@ -333,8 +349,6 @@ steps:
|
||||
- vllm/
|
||||
- tests/models
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_registry.py
|
||||
- pytest -v -s models/test_initialization.py
|
||||
|
||||
@ -360,23 +374,26 @@ steps:
|
||||
- pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model'
|
||||
- pytest -v -s models/embedding/language -m 'not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Standard) # 28min
|
||||
- label: Multi-Modal Models Test (Standard) # 40min
|
||||
#mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/decoder_only/audio_language
|
||||
- tests/models/decoder_only/vision_language
|
||||
- tests/models/embedding/vision_language
|
||||
- tests/models/encoder_decoder/audio_language
|
||||
- tests/models/encoder_decoder/vision_language
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal
|
||||
- pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model'
|
||||
- pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model'
|
||||
- pytest -v -s models/embedding/vision_language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/audio_language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/language -m core_model
|
||||
- pytest -v -s models/encoder_decoder/vision_language -m core_model
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 1 # 1h16m
|
||||
- label: Multi-Modal Models Test (Extended) 1 # 48m
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -459,7 +476,10 @@ steps:
|
||||
- vllm/worker/worker_base.py
|
||||
- vllm/worker/worker.py
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
- pytest -v -s ./compile/test_wrapper.py
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
@ -468,12 +488,31 @@ steps:
|
||||
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
fast_check: true
|
||||
source_file_dependencies:
|
||||
- vllm/plugins/
|
||||
- tests/plugins/
|
||||
commands:
|
||||
# begin platform plugin tests, all the code in-between runs on dummy platform
|
||||
- pip install -e ./plugins/vllm_add_dummy_platform
|
||||
- pytest -v -s plugins_tests/test_platform_plugins.py
|
||||
- pip uninstall vllm_add_dummy_platform -y
|
||||
# end platform plugin tests
|
||||
# other tests continue here:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pytest -v -s distributed/test_distributed_oot.py
|
||||
- pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process
|
||||
- pytest -v -s models/test_oot_registration.py # it needs a clean process
|
||||
|
||||
- label: Multi-step Tests (4 GPUs) # 36min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
@ -489,7 +528,9 @@ steps:
|
||||
- vllm/engine
|
||||
- tests/multi_step
|
||||
commands:
|
||||
- pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
# this test is quite flaky
|
||||
# TODO: investigate and fix.
|
||||
# - pytest -v -s multi_step/test_correctness_async_llm.py
|
||||
- pytest -v -s multi_step/test_correctness_llm.py
|
||||
|
||||
- label: Pipeline Parallelism Test # 45min
|
||||
@ -520,6 +561,7 @@ steps:
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
- pytest -v -s -x lora/test_llama_tp.py
|
||||
- pytest -v -s -x lora/test_minicpmv_tp.py
|
||||
|
||||
|
||||
- label: Weight Loading Multiple GPU Test # 33min
|
||||
|
27
.github/CODEOWNERS
vendored
27
.github/CODEOWNERS
vendored
@ -2,32 +2,35 @@
|
||||
# for more info about CODEOWNERS file
|
||||
|
||||
# This lists cover the "core" components of vLLM that require careful review
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill
|
||||
/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/core @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/engine/llm_engine.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/executor/executor_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/worker/worker_base.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill
|
||||
/vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
|
||||
/vllm/model_executor/guided_decoding @mgoin
|
||||
/vllm/multimodal @DarkLight1337 @ywang96
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-neuralmagic @njhill @ywang96 @comaniac @alexm-neuralmagic
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
|
||||
# Test ownership
|
||||
/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo
|
||||
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo
|
||||
/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo
|
||||
/tests/models @DarkLight1337 @ywang96
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/kernels @tlrmchlsmth @WoosukKwon
|
||||
/tests/quantization @mgoin @robertgshaw2-neuralmagic
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/.buildkite/lm-eval-harness @mgoin @simon-mo
|
||||
/tests/distributed/test_multi_node_assignment.py @youkaichao
|
||||
/tests/distributed/test_pipeline_parallel.py @youkaichao
|
||||
/tests/distributed/test_same_node.py @youkaichao
|
||||
/tests/multi_step @alexm-neuralmagic @comaniac
|
||||
/tests/multi_step @alexm-redhat @comaniac
|
||||
/tests/weight_loading @mgoin @youkaichao
|
||||
/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac
|
||||
|
@ -9,7 +9,7 @@ body:
|
||||
value: >
|
||||
#### Before submitting an issue, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue+sort%3Acreated-desc+).
|
||||
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/models/adding_model.html first to understand how to add a new model.
|
||||
#### We also highly recommend you read https://docs.vllm.ai/en/latest/contributing/model/adding_model.html first to understand how to add a new model.
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: The model to consider.
|
40
.github/workflows/actionlint.yml
vendored
40
.github/workflows/actionlint.yml
vendored
@ -1,40 +0,0 @@
|
||||
name: Lint GitHub Actions workflows
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '.github/workflows/*.ya?ml'
|
||||
- '.github/workflows/actionlint.*'
|
||||
- '.github/workflows/matchers/actionlint.json'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '.github/workflows/*.ya?ml'
|
||||
- '.github/workflows/actionlint.*'
|
||||
- '.github/workflows/matchers/actionlint.json'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Run actionlint"
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
tools/actionlint.sh -color
|
53
.github/workflows/clang-format.yml
vendored
53
.github/workflows/clang-format.yml
vendored
@ -1,53 +0,0 @@
|
||||
name: clang-format
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.h'
|
||||
- '**/*.cpp'
|
||||
- '**/*.cu'
|
||||
- '**/*.cuh'
|
||||
- '.github/workflows/clang-format.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.h'
|
||||
- '**/*.cpp'
|
||||
- '**/*.cu'
|
||||
- '**/*.cuh'
|
||||
- '.github/workflows/clang-format.yml'
|
||||
|
||||
jobs:
|
||||
clang-format:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install clang-format==18.1.5
|
||||
- name: Running clang-format
|
||||
run: |
|
||||
EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/quantization/gguf/ggml-common.h'
|
||||
'csrc/quantization/gguf/dequantize.cuh'
|
||||
'csrc/quantization/gguf/vecdotq.cuh'
|
||||
'csrc/quantization/gguf/mmq.cuh'
|
||||
'csrc/quantization/gguf/mmvq.cuh'
|
||||
)
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||
| xargs clang-format --dry-run --Werror
|
45
.github/workflows/codespell.yml
vendored
45
.github/workflows/codespell.yml
vendored
@ -1,45 +0,0 @@
|
||||
name: codespell
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "**/*.md"
|
||||
- "**/*.rst"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/codespell.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- "**/*.md"
|
||||
- "**/*.rst"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/codespell.yml
|
||||
|
||||
jobs:
|
||||
codespell:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Spelling check with codespell
|
||||
run: |
|
||||
codespell --toml pyproject.toml
|
5
.github/workflows/lint-and-deploy.yaml
vendored
5
.github/workflows/lint-and-deploy.yaml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
version: v3.10.1
|
||||
|
||||
- name: Run chart-testing (lint)
|
||||
run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/chart-helm --charts examples/chart-helm
|
||||
run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm
|
||||
|
||||
- name: Setup minio
|
||||
run: |
|
||||
@ -64,7 +64,8 @@ jobs:
|
||||
run: |
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/chart-helm -f examples/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
|
17
.github/workflows/matchers/ruff.json
vendored
17
.github/workflows/matchers/ruff.json
vendored
@ -1,17 +0,0 @@
|
||||
{
|
||||
"problemMatcher": [
|
||||
{
|
||||
"owner": "ruff",
|
||||
"pattern": [
|
||||
{
|
||||
"regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$",
|
||||
"file": 1,
|
||||
"line": 2,
|
||||
"column": 3,
|
||||
"code": 4,
|
||||
"message": 5
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
51
.github/workflows/mypy.yaml
vendored
51
.github/workflows/mypy.yaml
vendored
@ -1,51 +0,0 @@
|
||||
name: mypy
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- '**/*.py'
|
||||
- '.github/workflows/mypy.yaml'
|
||||
- 'tools/mypy.sh'
|
||||
- 'pyproject.toml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
# This workflow is only relevant when one of the following files changes.
|
||||
# However, we have github configured to expect and require this workflow
|
||||
# to run and pass before github with auto-merge a pull request. Until github
|
||||
# allows more flexible auto-merge policy, we can just run this on every PR.
|
||||
# It doesn't take that long to run, anyway.
|
||||
#paths:
|
||||
# - '**/*.py'
|
||||
# - '.github/workflows/mypy.yaml'
|
||||
# - 'tools/mypy.sh'
|
||||
# - 'pyproject.toml'
|
||||
|
||||
jobs:
|
||||
mypy:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install mypy==1.11.1
|
||||
pip install types-setuptools
|
||||
pip install types-PyYAML
|
||||
pip install types-requests
|
||||
pip install types-setuptools
|
||||
- name: Mypy
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/mypy.json"
|
||||
tools/mypy.sh 1 ${{ matrix.python-version }}
|
37
.github/workflows/png-lint.yml
vendored
37
.github/workflows/png-lint.yml
vendored
@ -1,37 +0,0 @@
|
||||
name: Lint PNG exports from excalidraw
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '*.excalidraw.png'
|
||||
- '.github/workflows/png-lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '*.excalidraw.png'
|
||||
- '.github/workflows/png-lint.yml'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
actionlint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Run png-lint.sh to check excalidraw exported images"
|
||||
run: |
|
||||
tools/png-lint.sh
|
19
.github/workflows/pre-commit.yml
vendored
Normal file
19
.github/workflows/pre-commit.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
with:
|
||||
extra_args: --all-files --hook-stage manual
|
52
.github/workflows/ruff.yml
vendored
52
.github/workflows/ruff.yml
vendored
@ -1,52 +0,0 @@
|
||||
name: ruff
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- pyproject.toml
|
||||
- requirements-lint.txt
|
||||
- .github/workflows/matchers/ruff.json
|
||||
- .github/workflows/ruff.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
# This workflow is only relevant when one of the following files changes.
|
||||
# However, we have github configured to expect and require this workflow
|
||||
# to run and pass before github with auto-merge a pull request. Until github
|
||||
# allows more flexible auto-merge policy, we can just run this on every PR.
|
||||
# It doesn't take that long to run, anyway.
|
||||
#paths:
|
||||
# - "**/*.py"
|
||||
# - pyproject.toml
|
||||
# - requirements-lint.txt
|
||||
# - .github/workflows/matchers/ruff.json
|
||||
# - .github/workflows/ruff.yml
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Analysing the code with ruff
|
||||
run: |
|
||||
echo "::add-matcher::.github/workflows/matchers/ruff.json"
|
||||
ruff check --output-format github .
|
||||
- name: Run isort
|
||||
run: |
|
||||
isort . --check-only
|
37
.github/workflows/shellcheck.yml
vendored
37
.github/workflows/shellcheck.yml
vendored
@ -1,37 +0,0 @@
|
||||
name: Lint shell scripts
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**/*.sh'
|
||||
- '.github/workflows/shellcheck.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- "main"
|
||||
paths:
|
||||
- '**/*.sh'
|
||||
- '.github/workflows/shellcheck.yml'
|
||||
|
||||
env:
|
||||
LC_ALL: en_US.UTF-8
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
shellcheck:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: "Checkout"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: "Check shell scripts"
|
||||
run: |
|
||||
tools/shellcheck.sh
|
32
.github/workflows/sphinx-lint.yml
vendored
32
.github/workflows/sphinx-lint.yml
vendored
@ -1,32 +0,0 @@
|
||||
name: Lint documentation
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "docs/**"
|
||||
|
||||
jobs:
|
||||
sphinx-lint:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements-lint.txt
|
||||
- name: Linting docs
|
||||
run: tools/sphinx-lint.sh
|
38
.github/workflows/yapf.yml
vendored
38
.github/workflows/yapf.yml
vendored
@ -1,38 +0,0 @@
|
||||
name: yapf
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- .github/workflows/yapf.yml
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**/*.py"
|
||||
- .github/workflows/yapf.yml
|
||||
|
||||
jobs:
|
||||
yapf:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install yapf==0.32.0
|
||||
pip install toml==0.10.2
|
||||
- name: Running yapf
|
||||
run: |
|
||||
yapf --diff --recursive .
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -79,10 +79,7 @@ instance/
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/source/getting_started/examples/*.rst
|
||||
!**/*.template.rst
|
||||
docs/source/getting_started/examples/*.md
|
||||
!**/*.template.md
|
||||
docs/source/getting_started/examples/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
|
93
.pre-commit-config.yaml
Normal file
93
.pre-commit-config.yaml
Normal file
@ -0,0 +1,93 @@
|
||||
default_stages:
|
||||
- pre-commit # Run locally
|
||||
- manual # Run in CI
|
||||
repos:
|
||||
- repo: https://github.com/google/yapf
|
||||
rev: v0.43.0
|
||||
hooks:
|
||||
- id: yapf
|
||||
args: [--in-place, --verbose]
|
||||
additional_dependencies: [toml] # TODO: Remove when yapf is upgraded
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--output-format, github]
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
rev: v2.4.0
|
||||
hooks:
|
||||
- id: codespell
|
||||
exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*'
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.13.2
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||
rev: v19.1.7
|
||||
hooks:
|
||||
- id: clang-format
|
||||
exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))'
|
||||
types_or: [c++, cuda]
|
||||
args: [--style=file, --verbose]
|
||||
- repo: https://github.com/jackdewinter/pymarkdown
|
||||
rev: v0.9.27
|
||||
hooks:
|
||||
- id: pymarkdown
|
||||
files: docs/.*
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.7
|
||||
hooks:
|
||||
- id: actionlint
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-local
|
||||
name: Run mypy for local Python installation
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: &mypy_deps [mypy==1.11.1, types-setuptools, types-PyYAML, types-requests]
|
||||
stages: [pre-commit] # Don't run in CI
|
||||
- id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.9
|
||||
entry: tools/mypy.sh 1 "3.9"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.10
|
||||
entry: tools/mypy.sh 1 "3.10"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.11
|
||||
entry: tools/mypy.sh 1 "3.11"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
|
||||
name: Run mypy for Python 3.12
|
||||
entry: tools/mypy.sh 1 "3.12"
|
||||
language: python
|
||||
types: [python]
|
||||
additional_dependencies: *mypy_deps
|
||||
stages: [manual] # Only run in CI
|
||||
- id: shellcheck
|
||||
name: Lint shell scripts
|
||||
entry: tools/shellcheck.sh
|
||||
language: script
|
||||
types: [shell]
|
||||
- id: png-lint
|
||||
name: Lint PNG exports from excalidraw
|
||||
entry: tools/png-lint.sh
|
||||
language: script
|
||||
types: [png]
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
84
CMakeLists.txt
Normal file → Executable file
84
CMakeLists.txt
Normal file → Executable file
@ -24,9 +24,6 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
# Suppress potential warnings about unused manually-specified variables
|
||||
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||
|
||||
# Prevent installation of dependencies (cutlass) by default.
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||
|
||||
#
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
@ -181,6 +178,31 @@ message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||
# Define other extension targets
|
||||
#
|
||||
|
||||
#
|
||||
# cumem_allocator extension
|
||||
#
|
||||
|
||||
set(VLLM_CUMEM_EXT_SRC
|
||||
"csrc/cumem_allocator.cpp")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_CUMEM_EXT_SRC}"
|
||||
CUDA_ARCHS "${CUDA_ARCHS}")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling cumem allocator extension.")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS cuda)
|
||||
define_gpu_extension_target(
|
||||
cumem_allocator
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_CUMEM_EXT_SRC}
|
||||
LIBRARIES ${CUMEM_LIBS}
|
||||
USE_SABI 3.8
|
||||
WITH_SOABI)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _C extension
|
||||
#
|
||||
@ -223,13 +245,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
|
||||
GIT_TAG v3.6.0
|
||||
GIT_PROGRESS TRUE
|
||||
|
||||
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
|
||||
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
|
||||
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
|
||||
GIT_SHALLOW FALSE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
endif()
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
@ -253,7 +275,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# Only build Marlin kernels if we are building for at least some compatible archs.
|
||||
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
|
||||
# are not supported by Machete yet.
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" ${CUDA_ARCHS})
|
||||
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
|
||||
if (MARLIN_ARCHS)
|
||||
set(MARLIN_SRCS
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
@ -274,8 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
|
||||
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
|
||||
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
@ -329,7 +351,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# 2:4 Sparse Kernels
|
||||
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
|
||||
# require CUDA 12.2 or later (and only work on Hopper, 9.0a for now).
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
@ -510,7 +532,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
|
||||
endif()
|
||||
|
||||
# vllm-flash-attn currently only supported on CUDA
|
||||
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
|
||||
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
return()
|
||||
endif ()
|
||||
|
||||
@ -533,7 +555,7 @@ endif()
|
||||
# They should be identical but if they aren't, this is a massive footgun.
|
||||
#
|
||||
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
|
||||
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
|
||||
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
|
||||
# If no component is specified, vllm-flash-attn is still installed.
|
||||
|
||||
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
|
||||
@ -545,43 +567,41 @@ if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
|
||||
endif()
|
||||
|
||||
if(VLLM_FLASH_ATTN_SRC_DIR)
|
||||
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn SOURCE_DIR
|
||||
${VLLM_FLASH_ATTN_SRC_DIR}
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb
|
||||
GIT_TAG d4e09037abf588af1ec47d0e966b237ee376876c
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
)
|
||||
endif()
|
||||
|
||||
# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
|
||||
set(VLLM_PARENT_BUILD ON)
|
||||
|
||||
# Ensure the vllm/vllm_flash_attn directory exists before installation
|
||||
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Make sure vllm-flash-attn install rules are nested under vllm/
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Fetch the vllm-flash-attn library
|
||||
FetchContent_MakeAvailable(vllm-flash-attn)
|
||||
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
|
||||
|
||||
# Restore the install prefix
|
||||
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
|
||||
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)
|
||||
|
||||
# Copy over the vllm-flash-attn python files
|
||||
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
|
||||
# case only one is built, in the case both are built redundant work is done)
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm/vllm_flash_attn
|
||||
COMPONENT vllm_flash_attn_c
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa2_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
install(
|
||||
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
|
||||
DESTINATION vllm_flash_attn
|
||||
COMPONENT _vllm_fa3_C
|
||||
FILES_MATCHING PATTERN "*.py"
|
||||
)
|
||||
|
||||
# Nothing after vllm-flash-attn, see comment about macros above
|
||||
|
46
Dockerfile
46
Dockerfile
@ -2,8 +2,8 @@
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
# Please update any changes made here to
|
||||
# docs/source/dev/dockerfile/dockerfile.md and
|
||||
# docs/source/assets/dev/dockerfile-stages-dependency.png
|
||||
# docs/source/contributing/dockerfile/dockerfile.md and
|
||||
# docs/source/assets/contributing/dockerfile-stages-dependency.png
|
||||
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
@ -52,7 +52,7 @@ WORKDIR /workspace
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
|
||||
python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu126 "torch==2.7.0.dev20250121+cu126" "torchvision==0.22.0.dev20250121"; \
|
||||
fi
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
@ -126,8 +126,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
|
||||
|
||||
# Check the size of the wheel if RUN_WHEEL_CHECK is true
|
||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
||||
# Default max size of the wheel is 250MB
|
||||
ARG VLLM_MAX_SIZE_MB=250
|
||||
# sync the default value with .buildkite/check-wheel-size.py
|
||||
ARG VLLM_MAX_SIZE_MB=300
|
||||
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
|
||||
ARG RUN_WHEEL_CHECK=true
|
||||
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
@ -149,7 +149,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
WORKDIR /vllm-workspace
|
||||
@ -194,12 +195,30 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install dist/*.whl --verbose
|
||||
|
||||
# How to build this FlashInfer wheel:
|
||||
# $ export FLASHINFER_ENABLE_AOT=1
|
||||
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
|
||||
# $ cd flashinfer
|
||||
# $ git checkout 524304395bd1d8cd7d07db083859523fcaa246a4
|
||||
# $ python3 setup.py bdist_wheel --dist-dir=dist --verbose
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
|
||||
python3 -m pip install https://wheels.vllm.ai/flashinfer/524304395bd1d8cd7d07db083859523fcaa246a4/flashinfer_python-0.2.0.post1-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl; \
|
||||
fi
|
||||
COPY examples examples
|
||||
|
||||
# Although we build Flashinfer with AOT mode, there's still
|
||||
# some issues w.r.t. JIT compilation. Therefore we need to
|
||||
# install build dependencies for JIT compilation.
|
||||
# TODO: Remove this once FlashInfer AOT wheel is fixed
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-build.txt
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
#################### TEST IMAGE ####################
|
||||
@ -234,8 +253,8 @@ RUN mv vllm test_docs/
|
||||
#################### TEST IMAGE ####################
|
||||
|
||||
#################### OPENAI API SERVER ####################
|
||||
# openai api server alternative
|
||||
FROM vllm-base AS vllm-openai
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
@ -247,5 +266,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
# define sagemaker first, so it is not default from `docker build`
|
||||
FROM vllm-openai-base AS vllm-sagemaker
|
||||
|
||||
COPY examples/online_serving/sagemaker-entrypoint.sh .
|
||||
RUN chmod +x sagemaker-entrypoint.sh
|
||||
ENTRYPOINT ["./sagemaker-entrypoint.sh"]
|
||||
|
||||
FROM vllm-openai-base AS vllm-openai
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
#################### OPENAI API SERVER ####################
|
||||
|
@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
|
||||
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
|
||||
pip install --upgrade pip && \
|
||||
pip install -r requirements-build.txt
|
||||
|
||||
@ -37,9 +37,9 @@ FROM cpu-test-1 AS build
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-cpu.txt requirements-cpu.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
|
||||
--mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
|
||||
pip install -v -r requirements-cpu.txt
|
||||
|
||||
COPY . .
|
||||
|
@ -1,4 +1,4 @@
|
||||
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
|
||||
FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# default base image
|
||||
# https://gallery.ecr.aws/neuron/pytorch-inference-neuronx
|
||||
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.2-ubuntu20.04"
|
||||
ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.5.1-neuronx-py310-sdk2.21.0-ubuntu22.04"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
||||
@ -15,16 +15,17 @@ RUN apt-get update && \
|
||||
ffmpeg libsm6 libxext6 libgl1
|
||||
|
||||
### Mount Point ###
|
||||
# When launching the container, mount the code directory to /app
|
||||
ARG APP_MOUNT=/app
|
||||
# When launching the container, mount the code directory to /workspace
|
||||
ARG APP_MOUNT=/workspace
|
||||
VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
|
||||
RUN python3 -m pip install sentencepiece transformers==4.45.2 -U
|
||||
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||
RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||
RUN python3 -m pip install neuronx-cc==2.16.345.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
||||
RUN python3 -m pip install pytest
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
@ -42,4 +43,7 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
# overwrite entrypoint to run bash script
|
||||
RUN echo "import subprocess; import sys; subprocess.check_call(sys.argv[1:])" > /usr/local/bin/dockerd-entrypoint.py
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
@ -14,6 +14,7 @@ ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
RUN python3 -m pip install -U pip
|
||||
# install build requirements
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt
|
||||
# build vLLM with OpenVINO backend
|
||||
|
@ -4,7 +4,7 @@ USER root
|
||||
|
||||
ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/"
|
||||
|
||||
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1
|
||||
RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 libssl-dev
|
||||
|
||||
# Some packages in requirements-cpu are installed here
|
||||
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
|
||||
@ -18,9 +18,8 @@ ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
# These packages will be in rocketce eventually
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
RUSTFLAGS='-L /opt/conda/lib' pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \
|
||||
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
|
||||
torch==2.3.1 \
|
||||
-r requirements-cpu.txt \
|
||||
|
261
Dockerfile.rocm
261
Dockerfile.rocm
@ -1,174 +1,119 @@
|
||||
# Default ROCm 6.2 base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0"
|
||||
# default base image
|
||||
ARG REMOTE_VLLM="0"
|
||||
ARG USE_CYTHON="0"
|
||||
ARG BUILD_RPD="1"
|
||||
ARG COMMON_WORKDIR=/app
|
||||
ARG BASE_IMAGE=rocm/vllm-dev:base
|
||||
|
||||
# Default ROCm ARCHes to build vLLM for.
|
||||
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
# Whether to install CK-based flash-attention
|
||||
# If 0, will not install flash-attention
|
||||
ARG BUILD_FA="1"
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
ARG FA_BRANCH="3cea2fb"
|
||||
|
||||
# Whether to build triton on rocm
|
||||
ARG BUILD_TRITON="1"
|
||||
ARG TRITON_BRANCH="e192dba"
|
||||
|
||||
### Base image build stage
|
||||
FROM $BASE_IMAGE AS base
|
||||
|
||||
# Import arg(s) defined before this build stage
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG ARG_PYTORCH_ROCM_ARCH
|
||||
ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}}
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
tmux \
|
||||
ccache \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# When launching the container, mount the code directory to /vllm-workspace
|
||||
ARG APP_MOUNT=/vllm-workspace
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
# Remove sccache so it doesn't interfere with ccache
|
||||
# TODO: implement sccache support across components
|
||||
RUN apt-get update -q -y && apt-get install -q -y \
|
||||
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
|
||||
# Remove sccache
|
||||
RUN python3 -m pip install --upgrade pip && pip install setuptools_scm
|
||||
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
|
||||
# Install torch == 2.6.0 on ROCm
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-6.2"*) \
|
||||
python3 -m pip uninstall -y torch torchvision \
|
||||
&& python3 -m pip install --pre \
|
||||
torch==2.6.0.dev20241113+rocm6.2 \
|
||||
'setuptools-scm>=8' \
|
||||
torchvision==0.20.0.dev20241113+rocm6.2 \
|
||||
--extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \
|
||||
*) ;; esac
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
ENV CCACHE_DIR=/root/.cache/ccache
|
||||
ARG COMMON_WORKDIR
|
||||
WORKDIR ${COMMON_WORKDIR}
|
||||
|
||||
|
||||
### AMD-SMI build stage
|
||||
FROM base AS build_amdsmi
|
||||
# Build amdsmi wheel always
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& python3 -m pip wheel . --wheel-dir=/install
|
||||
# -----------------------
|
||||
# vLLM fetch stages
|
||||
FROM base AS fetch_vllm_0
|
||||
ONBUILD COPY ./ vllm/
|
||||
FROM base AS fetch_vllm_1
|
||||
ARG VLLM_REPO="https://github.com/vllm-project/vllm.git"
|
||||
ARG VLLM_BRANCH="main"
|
||||
ONBUILD RUN git clone ${VLLM_REPO} \
|
||||
&& cd vllm \
|
||||
&& git checkout ${VLLM_BRANCH}
|
||||
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm
|
||||
|
||||
# -----------------------
|
||||
# vLLM build stages
|
||||
FROM fetch_vllm AS build_vllm
|
||||
ARG USE_CYTHON
|
||||
# Build vLLM
|
||||
RUN cd vllm \
|
||||
&& python3 -m pip install -r requirements-rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
FROM scratch AS export_vllm
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/dist/*.whl /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/requirements*.txt /
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/benchmarks /benchmarks
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/tests /tests
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/examples /examples
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm/.buildkite /.buildkite
|
||||
|
||||
### Flash-Attention wheel build stage
|
||||
FROM base AS build_fa
|
||||
ARG BUILD_FA
|
||||
ARG FA_GFX_ARCHS
|
||||
ARG FA_BRANCH
|
||||
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout "${FA_BRANCH}" \
|
||||
&& git submodule update --init \
|
||||
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
# -----------------------
|
||||
# Test vLLM image
|
||||
FROM base AS test
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Triton wheel build stage
|
||||
FROM base AS build_triton
|
||||
ARG BUILD_TRITON
|
||||
ARG TRITON_BRANCH
|
||||
# Build triton wheel if `BUILD_TRITON = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_TRITON" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& python3 -m pip install ninja cmake wheel pybind11 \
|
||||
&& git clone https://github.com/OpenAI/triton.git \
|
||||
&& cd triton \
|
||||
&& git checkout "${TRITON_BRANCH}" \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
|
||||
### Final vLLM build stage
|
||||
FROM base AS final
|
||||
# Import the vLLM development directory from the build context
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
|
||||
# Package upgrades for useful functionality or to avoid dependency issues
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard
|
||||
|
||||
|
||||
# Workaround for ray >= 2.10.0
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
# Silences the HF Tokenizers warning
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -Ur requirements-rocm.txt \
|
||||
&& python3 setup.py clean --all \
|
||||
&& python3 setup.py develop
|
||||
|
||||
# Copy amdsmi wheel into final image
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y amdsmi;
|
||||
|
||||
# Copy triton wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y triton; fi
|
||||
|
||||
# Copy flash-attn wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& python3 -m pip uninstall -y flash-attn; fi
|
||||
|
||||
# Install wheels that were built to the final image
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if ls libs/*.whl; then \
|
||||
python3 -m pip install libs/*.whl; fi
|
||||
WORKDIR /vllm-workspace
|
||||
ARG COMMON_WORKDIR
|
||||
COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
RUN cd /vllm-workspace \
|
||||
&& rm -rf vllm \
|
||||
&& python3 -m pip install -e tests/vllm_test_utils \
|
||||
&& python3 -m pip install lm-eval[api]==0.4.4 \
|
||||
&& python3 -m pip install pytest-shard
|
||||
|
||||
# -----------------------
|
||||
# Final vLLM image
|
||||
FROM base AS final
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually remove it so that later steps of numpy upgrade can continue
|
||||
RUN case "$(which python3)" in \
|
||||
*"/opt/conda/envs/py_3.9"*) \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
|
||||
*) ;; esac
|
||||
|
||||
RUN python3 -m pip install --upgrade huggingface-hub[cli]
|
||||
ARG BUILD_RPD
|
||||
RUN if [ ${BUILD_RPD} -eq "1" ]; then \
|
||||
git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git \
|
||||
&& cd rocmProfileData/rpd_tracer \
|
||||
&& pip install -r requirements.txt && cd ../ \
|
||||
&& make && make install \
|
||||
&& cd hipMarker && python3 setup.py install ; fi
|
||||
|
||||
# Install vLLM
|
||||
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
|
||||
cd /install \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& pip uninstall -y vllm \
|
||||
&& pip install *.whl
|
||||
|
||||
ARG COMMON_WORKDIR
|
||||
|
||||
# Copy over the benchmark scripts as well
|
||||
COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
|
||||
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
|
||||
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
# Performance environment variable.
|
||||
ENV HIP_FORCE_DEV_KERNARG=1
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
|
158
Dockerfile.rocm_base
Normal file
158
Dockerfile.rocm_base
Normal file
@ -0,0 +1,158 @@
|
||||
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete
|
||||
ARG HIPBLASLT_BRANCH="4d40e36"
|
||||
ARG HIPBLAS_COMMON_BRANCH="7c1566b"
|
||||
ARG LEGACY_HIPBLASLT_OPTION=
|
||||
ARG RCCL_BRANCH="648a58d"
|
||||
ARG RCCL_REPO="https://github.com/ROCm/rccl"
|
||||
ARG TRITON_BRANCH="e5be006"
|
||||
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
|
||||
ARG PYTORCH_BRANCH="8d4926e"
|
||||
ARG PYTORCH_VISION_BRANCH="v0.19.1"
|
||||
ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="b7d29fb"
|
||||
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
ENV PATH=/opt/rocm/llvm/bin:$PATH
|
||||
ENV ROCM_PATH=/opt/rocm
|
||||
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib:
|
||||
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942
|
||||
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
RUN mkdir -p /app
|
||||
WORKDIR /app
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
python${PYTHON_VERSION}-lib2to3 python-is-python3 \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython
|
||||
|
||||
FROM base AS build_hipblaslt
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG HIPBLAS_COMMON_BRANCH
|
||||
# Set to "--legacy_hipblas_direct" for ROCm<=6.2
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
RUN git clone https://github.com/ROCm/hipBLAS-common.git
|
||||
RUN cd hipBLAS-common \
|
||||
&& git checkout ${HIPBLAS_COMMON_BRANCH} \
|
||||
&& mkdir build \
|
||||
&& cd build \
|
||||
&& cmake .. \
|
||||
&& make package \
|
||||
&& dpkg -i ./*.deb
|
||||
RUN git clone https://github.com/ROCm/hipBLASLt
|
||||
RUN cd hipBLASLt \
|
||||
&& git checkout ${HIPBLASLT_BRANCH} \
|
||||
&& ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \
|
||||
&& cd build/release \
|
||||
&& make package
|
||||
RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install
|
||||
|
||||
FROM base AS build_rccl
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
RUN git clone ${RCCL_REPO}
|
||||
RUN cd rccl \
|
||||
&& git checkout ${RCCL_BRANCH} \
|
||||
&& ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH}
|
||||
RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install
|
||||
|
||||
FROM base AS build_triton
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
RUN git clone ${TRITON_REPO}
|
||||
RUN cd triton \
|
||||
&& git checkout ${TRITON_BRANCH} \
|
||||
&& cd python \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_amdsmi
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=dist
|
||||
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install
|
||||
|
||||
FROM base AS build_pytorch
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
ARG PYTORCH_REPO
|
||||
ARG PYTORCH_VISION_REPO
|
||||
ARG FA_BRANCH
|
||||
ARG FA_REPO
|
||||
RUN git clone ${PYTORCH_REPO} pytorch
|
||||
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
|
||||
pip install -r requirements.txt && git submodule update --init --recursive \
|
||||
&& python3 tools/amd_build/build_amd.py \
|
||||
&& CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& pip install dist/*.whl
|
||||
RUN git clone ${PYTORCH_VISION_REPO} vision
|
||||
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
|
||||
&& python3 setup.py bdist_wheel --dist-dir=dist \
|
||||
&& pip install dist/*.whl
|
||||
RUN git clone ${FA_REPO}
|
||||
RUN cd flash-attention \
|
||||
&& git checkout ${FA_BRANCH} \
|
||||
&& git submodule update --init \
|
||||
&& MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
|
||||
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
|
||||
&& cp /app/vision/dist/*.whl /app/install \
|
||||
&& cp /app/flash-attention/dist/*.whl /app/install
|
||||
|
||||
FROM base AS final
|
||||
RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \
|
||||
dpkg -i /install/*deb \
|
||||
&& sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \
|
||||
&& sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status
|
||||
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
|
||||
pip install /install/*.whl
|
||||
|
||||
ARG BASE_IMAGE
|
||||
ARG HIPBLASLT_BRANCH
|
||||
ARG LEGACY_HIPBLASLT_OPTION
|
||||
ARG RCCL_BRANCH
|
||||
ARG RCCL_REPO
|
||||
ARG TRITON_BRANCH
|
||||
ARG TRITON_REPO
|
||||
ARG PYTORCH_BRANCH
|
||||
ARG PYTORCH_VISION_BRANCH
|
||||
ARG PYTORCH_REPO
|
||||
ARG PYTORCH_VISION_REPO
|
||||
ARG FA_BRANCH
|
||||
ARG FA_REPO
|
||||
RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
|
||||
&& echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \
|
||||
&& echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \
|
||||
&& echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \
|
||||
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
|
||||
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
|
||||
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
|
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20241017"
|
||||
ARG NIGHTLY_DATE="20250124"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
|
33
README.md
33
README.md
@ -16,6 +16,8 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing).
|
||||
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||
@ -34,10 +36,12 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
## About
|
||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||
|
||||
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evloved into a community-driven project with contributions from both academia and industry.
|
||||
|
||||
vLLM is fast with:
|
||||
|
||||
- State-of-the-art serving throughput
|
||||
- Efficient management of attention key and value memory with **PagedAttention**
|
||||
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
||||
- Continuous batching of incoming requests
|
||||
- Fast model execution with CUDA/HIP graph
|
||||
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8.
|
||||
@ -60,7 +64,7 @@ vLLM is flexible and easy to use with:
|
||||
|
||||
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||
- Transformer-like LLMs (e.g., Llama)
|
||||
- Mixture-of-Expert LLMs (e.g., Mixtral)
|
||||
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
|
||||
- Embedding Models (e.g. E5-Mistral)
|
||||
- Multi-modal LLMs (e.g., LLaVA)
|
||||
|
||||
@ -68,16 +72,16 @@ Find the full list of supported models [here](https://docs.vllm.ai/en/latest/mod
|
||||
|
||||
## Getting Started
|
||||
|
||||
Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
||||
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
|
||||
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
|
||||
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
|
||||
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
|
||||
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
||||
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation/index.html)
|
||||
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
|
||||
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
|
||||
|
||||
## Contributing
|
||||
|
||||
@ -90,28 +94,33 @@ vLLM is a community project. Our compute resources for development and testing a
|
||||
|
||||
<!-- Note: Please sort them in alphabetical order. -->
|
||||
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
|
||||
|
||||
Cash Donations:
|
||||
- a16z
|
||||
- Dropbox
|
||||
- Sequoia Capital
|
||||
- Skywork AI
|
||||
- ZhenFund
|
||||
|
||||
Compute Resources:
|
||||
- AMD
|
||||
- Anyscale
|
||||
- AWS
|
||||
- Crusoe Cloud
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Dropbox
|
||||
- Google Cloud
|
||||
- Lambda Lab
|
||||
- Nebius
|
||||
- Novita AI
|
||||
- NVIDIA
|
||||
- Replicate
|
||||
- Roblox
|
||||
- RunPod
|
||||
- Sequoia Capital
|
||||
- Skywork AI
|
||||
- Trainy
|
||||
- UC Berkeley
|
||||
- UC San Diego
|
||||
- ZhenFund
|
||||
|
||||
Slack Sponsor: Anyscale
|
||||
|
||||
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
|
||||
|
||||
Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new).
|
||||
Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html).
|
||||
|
||||
---
|
||||
|
||||
|
@ -22,6 +22,7 @@ class RequestFuncInput:
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
model: str
|
||||
model_name: Optional[str] = None
|
||||
best_of: int = 1
|
||||
logprobs: Optional[int] = None
|
||||
extra_body: Optional[dict] = None
|
||||
@ -34,6 +35,7 @@ class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0.0
|
||||
output_tokens: int = 0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
@ -49,7 +51,8 @@ async def async_request_tgi(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
params = {
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_new_tokens": request_func_input.output_len,
|
||||
@ -78,7 +81,7 @@ async def async_request_tgi(
|
||||
continue
|
||||
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||
|
||||
#NOTE: Sometimes TGI returns a ping response without
|
||||
# NOTE: Sometimes TGI returns a ping response without
|
||||
# any data, we should skip it.
|
||||
if chunk_bytes.startswith(":"):
|
||||
continue
|
||||
@ -121,7 +124,8 @@ async def async_request_trt_llm(
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith("generate_stream")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
payload = {
|
||||
"accumulate_tokens": True,
|
||||
@ -155,7 +159,7 @@ async def async_request_trt_llm(
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
@ -185,7 +189,8 @@ async def async_request_deepspeed_mii(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
assert request_func_input.best_of == 1
|
||||
|
||||
payload = {
|
||||
@ -233,17 +238,23 @@ async def async_request_openai_completions(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": request_func_input.best_of,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
"logprobs": request_func_input.logprobs,
|
||||
"stream": True,
|
||||
"ignore_eos": request_func_input.ignore_eos,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
@ -254,7 +265,6 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
@ -269,15 +279,16 @@ async def async_request_openai_completions(
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
if chunk != "[DONE]":
|
||||
data = json.loads(chunk)
|
||||
|
||||
# NOTE: Some completion API might have a last
|
||||
# usage summary response without a token so we
|
||||
# want to check a token was generated
|
||||
if data["choices"][0]["text"]:
|
||||
if choices := data.get("choices"):
|
||||
# Note that text could be empty here
|
||||
# e.g. for special tokens
|
||||
text = choices[0].get("text")
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if not first_chunk_received:
|
||||
@ -291,7 +302,10 @@ async def async_request_openai_completions(
|
||||
most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
generated_text += text or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
if first_chunk_received:
|
||||
output.success = True
|
||||
else:
|
||||
@ -300,7 +314,7 @@ async def async_request_openai_completions(
|
||||
"Never received a valid chunk to calculate TTFT."
|
||||
"This response will be marked as failed!")
|
||||
output.generated_text = generated_text
|
||||
output.latency = latency
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@ -323,12 +337,14 @@ async def async_request_openai_chat_completions(
|
||||
"chat/completions"
|
||||
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@ -338,8 +354,12 @@ async def async_request_openai_chat_completions(
|
||||
"temperature": 0.0,
|
||||
"max_completion_tokens": request_func_input.output_len,
|
||||
"stream": True,
|
||||
"ignore_eos": request_func_input.ignore_eos,
|
||||
"stream_options": {
|
||||
"include_usage": True,
|
||||
},
|
||||
}
|
||||
if request_func_input.ignore_eos:
|
||||
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||
if request_func_input.extra_body:
|
||||
payload.update(request_func_input.extra_body)
|
||||
headers = {
|
||||
@ -365,17 +385,15 @@ async def async_request_openai_chat_completions(
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix(
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
delta = data["choices"][0]["delta"]
|
||||
if delta.get("content", None):
|
||||
if choices := data.get("choices"):
|
||||
content = choices[0]["delta"].get("content")
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
ttft = timestamp - st
|
||||
output.ttft = ttft
|
||||
|
||||
# Decoding phase
|
||||
@ -383,13 +401,16 @@ async def async_request_openai_chat_completions(
|
||||
output.itl.append(timestamp -
|
||||
most_recent_timestamp)
|
||||
|
||||
generated_text += delta["content"]
|
||||
generated_text += content or ""
|
||||
elif usage := data.get("usage"):
|
||||
output.output_tokens = usage.get(
|
||||
"completion_tokens")
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.latency = most_recent_timestamp - st
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@ -417,14 +438,35 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
pretrained_model_name_or_path: str, trust_remote_code: bool
|
||||
pretrained_model_name_or_path: str,
|
||||
tokenizer_mode: str = "auto",
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||
pretrained_model_name_or_path):
|
||||
pretrained_model_name_or_path = get_model(
|
||||
pretrained_model_name_or_path)
|
||||
return AutoTokenizer.from_pretrained(pretrained_model_name_or_path,
|
||||
trust_remote_code=trust_remote_code)
|
||||
if tokenizer_mode == "slow":
|
||||
if kwargs.get("use_fast", False):
|
||||
raise ValueError(
|
||||
"Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||
kwargs["use_fast"] = False
|
||||
if tokenizer_mode == "mistral":
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
except ImportError as e:
|
||||
raise ImportError("MistralTokenizer requires vllm package.\n"
|
||||
"Please install it with `pip install vllm` "
|
||||
"to use mistral tokenizer mode.") from e
|
||||
return MistralTokenizer.from_pretrained(
|
||||
str(pretrained_model_name_or_path))
|
||||
else:
|
||||
return AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
ASYNC_REQUEST_FUNCS = {
|
||||
|
@ -13,6 +13,7 @@ from tqdm import tqdm
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -40,6 +41,20 @@ def main(args: argparse.Namespace):
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
def llm_generate():
|
||||
if not args.use_beam_search:
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
else:
|
||||
llm.beam_search(
|
||||
dummy_prompts,
|
||||
BeamSearchParams(
|
||||
beam_width=args.n,
|
||||
max_tokens=args.output_len,
|
||||
ignore_eos=True,
|
||||
))
|
||||
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
with torch.profiler.profile(
|
||||
@ -49,15 +64,11 @@ def main(args: argparse.Namespace):
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
print(p.key_averages())
|
||||
llm_generate()
|
||||
print(p.key_averages().table(sort_by="self_cuda_time_total"))
|
||||
else:
|
||||
start_time = time.perf_counter()
|
||||
llm.generate(dummy_prompts,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
llm_generate()
|
||||
end_time = time.perf_counter()
|
||||
latency = end_time - start_time
|
||||
return latency
|
||||
|
183
benchmarks/benchmark_long_document_qa_throughput.py
Normal file
183
benchmarks/benchmark_long_document_qa_throughput.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""
|
||||
Offline benchmark to test the long document QA throughput.
|
||||
|
||||
Example usage:
|
||||
# This workload samples 8 different prompts with a default input
|
||||
# length of 20000 tokens, then replicates each prompt 2 times
|
||||
# in random order.
|
||||
python benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--repeat-count 2
|
||||
|
||||
Commandline arguments:
|
||||
--num-documents: The number of documents to sample prompts from.
|
||||
|
||||
--document-length: The length of each document in tokens.
|
||||
(Optional, default: 20000)
|
||||
|
||||
--output-len: The number of tokens to generate for each prompt.
|
||||
(Optional, default: 10)
|
||||
|
||||
--repeat-count: The number of times to repeat each prompt.
|
||||
(Optional, default: 2)
|
||||
|
||||
--repeat-mode: The mode to repeat prompts. The supported modes are:
|
||||
- 'random': shuffle the prompts randomly. (Default)
|
||||
- 'tile': the entire prompt list is repeated in sequence. (Potentially
|
||||
lowest cache hit)
|
||||
- 'interleave': each prompt is repeated consecutively before
|
||||
moving to the next element. (Highest cache hit)
|
||||
|
||||
--shuffle-seed: Random seed when the repeat mode is "random".
|
||||
(Optional, default: 0)
|
||||
|
||||
In the meantime, it also supports all the vLLM engine args to initialize the
|
||||
LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
|
||||
details.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import random
|
||||
import time
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
|
||||
"""
|
||||
Test long document QA with the given prompts and sampling parameters.
|
||||
Print the time spent in processing all the prompts.
|
||||
|
||||
Args:
|
||||
llm: The language model used for generating responses.
|
||||
sampling_params: Sampling parameter used to generate the response.
|
||||
prompts: A list of prompt strings to be processed by the LLM.
|
||||
"""
|
||||
start_time = time.time()
|
||||
llm.generate(prompts, sampling_params=sampling_params)
|
||||
end_time = time.time()
|
||||
print(f"Time to execute all requests: {end_time - start_time:.4f} secs")
|
||||
|
||||
|
||||
def repeat_prompts(prompts, repeat_count, mode: str):
|
||||
"""
|
||||
Repeat each prompt in the list for a specified number of times.
|
||||
The order of prompts in the output list depends on the mode.
|
||||
|
||||
Args:
|
||||
prompts: A list of prompts to be repeated.
|
||||
repeat_count: The number of times each prompt is repeated.
|
||||
mode: The mode of repetition. Supported modes are:
|
||||
- 'random': Shuffle the prompts randomly after repetition.
|
||||
- 'tile': Repeat the entire prompt list in sequence.
|
||||
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
||||
- 'interleave': Repeat each prompt consecutively before moving to
|
||||
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
||||
|
||||
Returns:
|
||||
A list of repeated prompts in the specified order.
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid mode is provided.
|
||||
"""
|
||||
print("Repeat mode: ", mode)
|
||||
if mode == 'random':
|
||||
repeated_prompts = prompts * repeat_count
|
||||
random.shuffle(repeated_prompts)
|
||||
return repeated_prompts
|
||||
elif mode == 'tile':
|
||||
return prompts * repeat_count
|
||||
elif mode == 'interleave':
|
||||
repeated_prompts = []
|
||||
for prompt in prompts:
|
||||
repeated_prompts.extend([prompt] * repeat_count)
|
||||
return repeated_prompts
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}, only support "
|
||||
"'random', 'tile', 'interleave'")
|
||||
|
||||
|
||||
def main(args):
|
||||
random.seed(args.shuffle_seed)
|
||||
|
||||
# Prepare the prompts:
|
||||
# we append the document id at the beginning to avoid any of the document
|
||||
# being the prefix of other documents
|
||||
prompts = [
|
||||
str(i) + ' '.join(['hi'] * args.document_length)
|
||||
for i in range(args.num_documents)
|
||||
]
|
||||
|
||||
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
|
||||
|
||||
warmup_prompts = [
|
||||
"This is warm up request " + str(i) + \
|
||||
' '.join(['hi'] * args.document_length)
|
||||
for i in range(args.num_documents)]
|
||||
|
||||
# Create the LLM engine
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
llm = LLM(**dataclasses.asdict(engine_args))
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||
|
||||
print("------warm up------")
|
||||
test_long_document_qa(
|
||||
llm=llm,
|
||||
prompts=warmup_prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
print("------start generating------")
|
||||
test_long_document_qa(
|
||||
llm=llm,
|
||||
prompts=prompts,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(
|
||||
description=
|
||||
'Benchmark the performance with or without automatic prefix caching.')
|
||||
|
||||
parser.add_argument(
|
||||
'--document-length',
|
||||
type=int,
|
||||
# Roughly the number of tokens for a system paper,
|
||||
# excluding images
|
||||
default=20000,
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
|
||||
parser.add_argument('--num-documents',
|
||||
type=int,
|
||||
default=8,
|
||||
help='Range of input lengths for sampling prompts,'
|
||||
'specified as "min:max" (e.g., "128:256").')
|
||||
|
||||
parser.add_argument('--output-len', type=int, default=10)
|
||||
|
||||
parser.add_argument('--repeat-count',
|
||||
type=int,
|
||||
default=2,
|
||||
help='Number of times to repeat each prompt')
|
||||
|
||||
parser.add_argument("--repeat-mode",
|
||||
type=str,
|
||||
default='random',
|
||||
help='The mode to repeat prompts. The supported '
|
||||
'modes are "random", "tile", and "interleave". '
|
||||
'See repeat_prompts() in the source code for details.')
|
||||
|
||||
parser.add_argument("--shuffle-seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help='Random seed when the repeat mode is "random"')
|
||||
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
@ -10,7 +10,8 @@ Fixed example usage:
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 1 \
|
||||
--repeat-count 100
|
||||
--repeat-count 100 \
|
||||
--input-length-range 128:256
|
||||
|
||||
ShareGPT example usage:
|
||||
# This command samples 20 prompts with input lengths
|
||||
|
@ -25,6 +25,7 @@ On the client side, run:
|
||||
import argparse
|
||||
import asyncio
|
||||
import base64
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
@ -199,7 +200,7 @@ def sample_sonnet_requests(
|
||||
return sampled_requests
|
||||
|
||||
|
||||
def sample_mmmu_pro_vision_requests(
|
||||
def sample_vision_arena_requests(
|
||||
dataset,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -211,13 +212,7 @@ def sample_mmmu_pro_vision_requests(
|
||||
if len(sampled_requests) == num_requests:
|
||||
break
|
||||
|
||||
# MMMU-Pro vision direct prompt
|
||||
# Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
|
||||
prompt = (
|
||||
"Answer with the option letter from the given choices directly. "
|
||||
"The last line of your response should be of the following "
|
||||
"format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
|
||||
"options.")
|
||||
prompt = data["turns"][0][0]['content']
|
||||
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
if fixed_output_len is None:
|
||||
@ -229,10 +224,10 @@ def sample_mmmu_pro_vision_requests(
|
||||
output_len = fixed_output_len
|
||||
|
||||
assert isinstance(
|
||||
data["image"],
|
||||
data["images"][0],
|
||||
Image), ("Input image format must be `PIL.Image.Image`, "
|
||||
f"given {type(data['image'])}.")
|
||||
image: Image = data["image"]
|
||||
image: Image = data["images"][0]
|
||||
image = image.convert("RGB")
|
||||
image_data = io.BytesIO()
|
||||
image.save(image_data, format='JPEG')
|
||||
@ -251,7 +246,7 @@ def sample_mmmu_pro_vision_requests(
|
||||
|
||||
def sample_hf_requests(
|
||||
dataset_path: str,
|
||||
dataset_subset: str,
|
||||
dataset_subset: Optional[str],
|
||||
dataset_split: str,
|
||||
num_requests: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
@ -259,19 +254,17 @@ def sample_hf_requests(
|
||||
fixed_output_len: Optional[int] = None,
|
||||
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
|
||||
|
||||
# Special case for MMMU-Pro vision dataset
|
||||
if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
|
||||
assert dataset_split == "test"
|
||||
# Special case for vision_arena dataset
|
||||
if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
|
||||
and dataset_subset is None:
|
||||
assert dataset_split == "train"
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
split=dataset_split,
|
||||
streaming=True)
|
||||
assert "image" in dataset.features, (
|
||||
"MMMU/MMMU_Pro vision dataset must have 'image' column.")
|
||||
filter_func = lambda x: isinstance(x["image"], Image)
|
||||
dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
|
||||
return sample_mmmu_pro_vision_requests(dataset, num_requests,
|
||||
tokenizer, fixed_output_len)
|
||||
dataset = dataset.shuffle(seed=random_seed)
|
||||
return sample_vision_arena_requests(dataset, num_requests, tokenizer,
|
||||
fixed_output_len)
|
||||
|
||||
dataset = load_dataset(dataset_path,
|
||||
name=dataset_subset,
|
||||
@ -423,7 +416,7 @@ def calculate_metrics(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[float],
|
||||
gootput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
||||
actual_output_lens: List[int] = []
|
||||
total_input = 0
|
||||
@ -436,19 +429,23 @@ def calculate_metrics(
|
||||
e2els: List[float] = []
|
||||
for i in range(len(outputs)):
|
||||
if outputs[i].success:
|
||||
# We use the tokenizer to count the number of output tokens for all
|
||||
# serving backends instead of looking at len(outputs[i].itl) since
|
||||
# multiple output tokens may be bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
output_len = outputs[i].output_tokens
|
||||
|
||||
if output_len is None:
|
||||
# We use the tokenizer to count the number of output tokens
|
||||
# for some serving backends instead of looking at
|
||||
# len(outputs[i].itl) since multiple output tokens may be
|
||||
# bundled together
|
||||
# Note : this may inflate the output token count slightly
|
||||
output_len = len(
|
||||
tokenizer(outputs[i].generated_text,
|
||||
add_special_tokens=False).input_ids)
|
||||
actual_output_lens.append(output_len)
|
||||
total_input += input_requests[i][1]
|
||||
tpot = 0
|
||||
if output_len > 1:
|
||||
tpot = (outputs[i].latency - outputs[i].ttft) / (output_len -
|
||||
1)
|
||||
latency_minus_ttft = outputs[i].latency - outputs[i].ttft
|
||||
tpot = latency_minus_ttft / (output_len - 1)
|
||||
tpots.append(tpot)
|
||||
# Note: if output_len <= 1, we regard tpot as 0 for goodput
|
||||
all_tpots.append(tpot)
|
||||
@ -459,21 +456,21 @@ def calculate_metrics(
|
||||
else:
|
||||
actual_output_lens.append(0)
|
||||
|
||||
if gootput_config_dict:
|
||||
if goodput_config_dict:
|
||||
valid_metrics = []
|
||||
slo_values = []
|
||||
|
||||
if "ttft" in gootput_config_dict:
|
||||
if "ttft" in goodput_config_dict:
|
||||
valid_metrics.append(ttfts)
|
||||
slo_values.append(gootput_config_dict["ttft"] /
|
||||
slo_values.append(goodput_config_dict["ttft"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "tpot" in gootput_config_dict:
|
||||
if "tpot" in goodput_config_dict:
|
||||
valid_metrics.append(all_tpots)
|
||||
slo_values.append(gootput_config_dict["tpot"] /
|
||||
slo_values.append(goodput_config_dict["tpot"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
if "e2el" in gootput_config_dict:
|
||||
if "e2el" in goodput_config_dict:
|
||||
valid_metrics.append(e2els)
|
||||
slo_values.append(gootput_config_dict["e2el"] /
|
||||
slo_values.append(goodput_config_dict["e2el"] /
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION)
|
||||
|
||||
for req_metric in zip(*valid_metrics):
|
||||
@ -525,6 +522,7 @@ async def benchmark(
|
||||
api_url: str,
|
||||
base_url: str,
|
||||
model_id: str,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
logprobs: Optional[int],
|
||||
@ -536,7 +534,7 @@ async def benchmark(
|
||||
selected_percentile_metrics: List[str],
|
||||
selected_percentiles: List[str],
|
||||
ignore_eos: bool,
|
||||
gootput_config_dict: Dict[str, float],
|
||||
goodput_config_dict: Dict[str, float],
|
||||
max_concurrency: Optional[int],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@ -553,6 +551,7 @@ async def benchmark(
|
||||
"Multi-modal content is only supported on 'openai-chat' backend.")
|
||||
test_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=test_prompt_len,
|
||||
@ -573,6 +572,7 @@ async def benchmark(
|
||||
if profile:
|
||||
print("Starting profiler...")
|
||||
profile_input = RequestFuncInput(model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=test_prompt,
|
||||
api_url=base_url + "/start_profile",
|
||||
prompt_len=test_prompt_len,
|
||||
@ -616,6 +616,7 @@ async def benchmark(
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
prompt, prompt_len, output_len, mm_content = request
|
||||
request_func_input = RequestFuncInput(model=model_id,
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
@ -657,7 +658,7 @@ async def benchmark(
|
||||
tokenizer=tokenizer,
|
||||
selected_percentile_metrics=selected_percentile_metrics,
|
||||
selected_percentiles=selected_percentiles,
|
||||
gootput_config_dict=gootput_config_dict,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
)
|
||||
|
||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
||||
@ -669,7 +670,7 @@ async def benchmark(
|
||||
metrics.total_output))
|
||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
||||
metrics.request_throughput))
|
||||
if gootput_config_dict:
|
||||
if goodput_config_dict:
|
||||
print("{:<40} {:<10.2f}".format("Request goodput (req/s):",
|
||||
metrics.request_goodput))
|
||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
||||
@ -684,7 +685,7 @@ async def benchmark(
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:":
|
||||
metrics.request_goodput if gootput_config_dict else None,
|
||||
metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -740,11 +741,11 @@ async def benchmark(
|
||||
|
||||
def check_goodput_args(args):
|
||||
# Check and parse goodput arguments
|
||||
gootput_config_dict = {}
|
||||
goodput_config_dict = {}
|
||||
VALID_NAMES = ["ttft", "tpot", "e2el"]
|
||||
if args.goodput:
|
||||
gootput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in gootput_config_dict.items():
|
||||
goodput_config_dict = parse_goodput(args.goodput)
|
||||
for slo_name, slo_val in goodput_config_dict.items():
|
||||
if slo_name not in VALID_NAMES:
|
||||
raise ValueError(
|
||||
f"Invalid metric name found, {slo_name}: {slo_val}. "
|
||||
@ -755,22 +756,22 @@ def check_goodput_args(args):
|
||||
f"Invalid value found, {slo_name}: {slo_val}. "
|
||||
"The service level objective value should be "
|
||||
"non-negative.")
|
||||
return gootput_config_dict
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def parse_goodput(slo_pairs):
|
||||
gootput_config_dict = {}
|
||||
goodput_config_dict = {}
|
||||
try:
|
||||
for slo_pair in slo_pairs:
|
||||
slo_name, slo_val = slo_pair.split(":")
|
||||
gootput_config_dict[slo_name] = float(slo_val)
|
||||
goodput_config_dict[slo_name] = float(slo_val)
|
||||
except ValueError as err:
|
||||
raise argparse.ArgumentTypeError(
|
||||
"Invalid format found for service level objectives. "
|
||||
"Specify service level objectives for goodput as \"KEY:VALUE\" "
|
||||
"pairs, where the key is a metric name, and the value is a "
|
||||
"number in milliseconds.") from err
|
||||
return gootput_config_dict
|
||||
return goodput_config_dict
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -780,6 +781,7 @@ def main(args: argparse.Namespace):
|
||||
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
model_name = args.served_model_name
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
tokenizer_mode = args.tokenizer_mode
|
||||
|
||||
@ -869,7 +871,11 @@ def main(args: argparse.Namespace):
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
|
||||
gootput_config_dict = check_goodput_args(args)
|
||||
goodput_config_dict = check_goodput_args(args)
|
||||
|
||||
# Avoid GC processing "static" data - reduce pause times.
|
||||
gc.collect()
|
||||
gc.freeze()
|
||||
|
||||
benchmark_result = asyncio.run(
|
||||
benchmark(
|
||||
@ -877,6 +883,7 @@ def main(args: argparse.Namespace):
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
model_name=model_name,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
logprobs=args.logprobs,
|
||||
@ -890,7 +897,7 @@ def main(args: argparse.Namespace):
|
||||
float(p) for p in args.metric_percentiles.split(",")
|
||||
],
|
||||
ignore_eos=args.ignore_eos,
|
||||
gootput_config_dict=gootput_config_dict,
|
||||
goodput_config_dict=goodput_config_dict,
|
||||
max_concurrency=args.max_concurrency,
|
||||
))
|
||||
|
||||
@ -919,8 +926,8 @@ def main(args: argparse.Namespace):
|
||||
)
|
||||
|
||||
# Traffic
|
||||
result_json["request_rate"] = (
|
||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
||||
result_json["request_rate"] = (args.request_rate if args.request_rate
|
||||
< float("inf") else "inf")
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
@ -1222,5 +1229,12 @@ if __name__ == "__main__":
|
||||
'always use the slow tokenizer. \n* '
|
||||
'"mistral" will always use the `mistral_common` tokenizer.')
|
||||
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name used in the API. "
|
||||
"If not specified, the model name will be the "
|
||||
"same as the ``--model`` argument. ")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
1147
benchmarks/kernels/benchmark_lora.py
Normal file
1147
benchmarks/kernels/benchmark_lora.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
import ray
|
||||
@ -13,6 +14,9 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
|
||||
) else torch.float8_e4m3fn
|
||||
|
||||
|
||||
class BenchmarkConfig(TypedDict):
|
||||
BLOCK_SIZE_M: int
|
||||
@ -80,8 +84,8 @@ def benchmark_config(
|
||||
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||
|
||||
w1 = w1.to(torch.float8_e4m3fn)
|
||||
w2 = w2.to(torch.float8_e4m3fn)
|
||||
w1 = w1.to(FP8_DTYPE)
|
||||
w2 = w2.to(FP8_DTYPE)
|
||||
|
||||
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
||||
|
||||
@ -141,28 +145,172 @@ def benchmark_config(
|
||||
return avg
|
||||
|
||||
|
||||
def get_configs_compute_bound() -> List[Dict[str, int]]:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
def get_rocm_tuning_space(use_fp16):
|
||||
block_mn_range = [16, 32, 64, 128, 256]
|
||||
block_k_range = [16, 32, 64, 128, 256]
|
||||
if not use_fp16:
|
||||
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
|
||||
num_warps_range = [1, 2, 4, 8]
|
||||
group_m_range = [1, 4, 8, 16, 32]
|
||||
num_stage_range = [2]
|
||||
waves_per_eu_range = [0]
|
||||
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
|
||||
kpack_range = [1, 2] if use_fp16 else []
|
||||
|
||||
param_ranges = {
|
||||
"BLOCK_SIZE_M": block_mn_range,
|
||||
"BLOCK_SIZE_N": block_mn_range,
|
||||
"BLOCK_SIZE_K": block_k_range,
|
||||
"GROUP_SIZE_M": group_m_range,
|
||||
"num_warps": num_warps_range,
|
||||
"num_stages": num_stage_range,
|
||||
"waves_per_eu": waves_per_eu_range,
|
||||
}
|
||||
if use_fp16:
|
||||
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
|
||||
param_ranges["kpack"] = kpack_range
|
||||
|
||||
return param_ranges
|
||||
|
||||
|
||||
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
for num_stages in [2, 3, 4, 5]:
|
||||
for block_m in [16, 32, 64, 128, 256]:
|
||||
for block_k in [64, 128, 256]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
for num_warps in [4, 8]:
|
||||
for group_size in [1, 16, 32, 64]:
|
||||
configs.append({
|
||||
"BLOCK_SIZE_M": block_m,
|
||||
"BLOCK_SIZE_N": block_n,
|
||||
"BLOCK_SIZE_K": block_k,
|
||||
"GROUP_SIZE_M": group_size,
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
})
|
||||
|
||||
if current_platform.is_rocm():
|
||||
param_ranges = get_rocm_tuning_space(use_fp16)
|
||||
else:
|
||||
# Reduced search space for faster tuning.
|
||||
# TODO(woosuk): Increase the search space and use a performance model to
|
||||
# prune the search space.
|
||||
block_m_range = [16, 32, 64, 128, 256]
|
||||
block_n_range = [32, 64, 128, 256]
|
||||
block_k_range = [64, 128, 256]
|
||||
num_warps_range = [4, 8]
|
||||
group_m_range = [1, 16, 32, 64]
|
||||
num_stage_range = [2, 3, 4, 5]
|
||||
|
||||
param_ranges = {
|
||||
"BLOCK_SIZE_M": block_m_range,
|
||||
"BLOCK_SIZE_N": block_n_range,
|
||||
"BLOCK_SIZE_K": block_k_range,
|
||||
"GROUP_SIZE_M": group_m_range,
|
||||
"num_warps": num_warps_range,
|
||||
"num_stages": num_stage_range,
|
||||
}
|
||||
|
||||
keys, values = zip(*param_ranges.items())
|
||||
for config_values in product(*values):
|
||||
config = dict(zip(keys, config_values))
|
||||
configs.append(config)
|
||||
return configs
|
||||
|
||||
|
||||
def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
|
||||
search_space, is_fp16):
|
||||
N1, K1 = shard_intermediate_size, hidden_size
|
||||
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||
pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
|
||||
is_fp16)
|
||||
pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
|
||||
is_fp16)
|
||||
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||
return search_space
|
||||
|
||||
|
||||
# The following code is inspired by ROCm/Triton GEMM tuning script:
|
||||
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
|
||||
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
||||
pruned_configs = []
|
||||
elemBytes_a = 2 if is_fp16 else 1
|
||||
elemBytes_b = 2 if is_fp16 else 1
|
||||
|
||||
mfma = 16 if M < 32 or N < 32 else 32
|
||||
|
||||
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||
large_gemm = False
|
||||
if M >= 2048 and N >= 2048:
|
||||
large_gemm = True
|
||||
|
||||
for config in configs:
|
||||
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||
num_warps = config.get("num_warps")
|
||||
|
||||
if is_fp16:
|
||||
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
||||
if matrix_instr_nonkdim > mfma:
|
||||
continue
|
||||
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
# some layouts could not work properly in case
|
||||
# number elements per thread is less 1
|
||||
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
SPLIT_K = config.get("SPLIT_K", 1)
|
||||
GROUP_M = config.get("GROUP_SIZE_M")
|
||||
if is_fp16:
|
||||
if (matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||
or matrix_instr_nonkdim > BLOCK_SIZE_N):
|
||||
continue
|
||||
if (matrix_instr_nonkdim >= M
|
||||
and matrix_instr_nonkdim != BLOCK_SIZE_M):
|
||||
continue
|
||||
if (matrix_instr_nonkdim >= N
|
||||
and matrix_instr_nonkdim != BLOCK_SIZE_N):
|
||||
continue
|
||||
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||
# unless BLOCK_SIZE is already small enough
|
||||
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
||||
continue
|
||||
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
||||
continue
|
||||
# skip large split_k when not necessary
|
||||
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
||||
continue
|
||||
# skip split_k that leads to EVEN_K = false
|
||||
leap = SPLIT_K * BLOCK_SIZE_K
|
||||
modv = K % leap
|
||||
if modv != 0:
|
||||
continue
|
||||
# skip large GROUP_M
|
||||
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||
continue
|
||||
# out of shared memory resource
|
||||
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
|
||||
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
|
||||
if LDS > 65536:
|
||||
continue
|
||||
# Skip small block sizes and num_warps for large gemm
|
||||
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
||||
if large_gemm:
|
||||
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
||||
continue
|
||||
if BLOCK_SIZE_K < 64:
|
||||
continue
|
||||
if num_warps < 4:
|
||||
continue
|
||||
|
||||
pruned_configs.append(config)
|
||||
|
||||
return pruned_configs
|
||||
|
||||
|
||||
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
|
||||
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
|
||||
|
||||
|
||||
def merge_unique_dicts(list1, list2):
|
||||
result = []
|
||||
combined_list = list1.copy()
|
||||
combined_list.extend(list2)
|
||||
for dictionary in combined_list:
|
||||
if dictionary not in result:
|
||||
result.append(dictionary)
|
||||
return result
|
||||
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
class BenchmarkWorker:
|
||||
|
||||
@ -170,6 +318,10 @@ class BenchmarkWorker:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(seed)
|
||||
self.seed = seed
|
||||
# Get the device ID to allocate tensors and kernels
|
||||
# on the respective GPU. This is required for Ray to work
|
||||
# correctly with multi-GPU tuning on the ROCm platform.
|
||||
self.device_id = int(ray.get_gpu_ids()[0])
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
@ -217,25 +369,33 @@ class BenchmarkWorker:
|
||||
) -> Dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=10)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
if current_platform.is_rocm():
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = prune_rocm_search_space(num_tokens,
|
||||
shard_intermediate_size,
|
||||
hidden_size, search_space,
|
||||
is_fp16)
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
with torch.cuda.device(self.device_id):
|
||||
for config in tqdm(search_space):
|
||||
try:
|
||||
kernel_time = benchmark_config(config,
|
||||
num_tokens,
|
||||
num_experts,
|
||||
shard_intermediate_size,
|
||||
hidden_size,
|
||||
topk,
|
||||
dtype,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16,
|
||||
num_iters=20)
|
||||
except triton.runtime.autotuner.OutOfResources:
|
||||
# Some configurations may be invalid and fail to compile.
|
||||
continue
|
||||
|
||||
if kernel_time < best_time:
|
||||
best_time = kernel_time
|
||||
best_config = config
|
||||
now = datetime.now()
|
||||
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||
assert best_config is not None
|
||||
@ -244,12 +404,27 @@ class BenchmarkWorker:
|
||||
|
||||
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
return {
|
||||
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||
"num_warps": config["num_warps"],
|
||||
"num_stages": config["num_stages"],
|
||||
"BLOCK_SIZE_M":
|
||||
config["BLOCK_SIZE_M"],
|
||||
"BLOCK_SIZE_N":
|
||||
config["BLOCK_SIZE_N"],
|
||||
"BLOCK_SIZE_K":
|
||||
config["BLOCK_SIZE_K"],
|
||||
"GROUP_SIZE_M":
|
||||
config["GROUP_SIZE_M"],
|
||||
"num_warps":
|
||||
config["num_warps"],
|
||||
"num_stages":
|
||||
config["num_stages"],
|
||||
**({
|
||||
"waves_per_eu": config["waves_per_eu"]
|
||||
} if "waves_per_eu" in config else {}),
|
||||
**({
|
||||
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
|
||||
} if "matrix_instr_nonkdim" in config else {}),
|
||||
**({
|
||||
"kpack": config["kpack"]
|
||||
} if "kpack" in config else {}),
|
||||
}
|
||||
|
||||
|
||||
@ -275,7 +450,8 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
|
||||
config = AutoConfig.from_pretrained(args.model)
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.model, trust_remote_code=args.trust_remote_code)
|
||||
if config.architectures[0] == "DbrxForCausalLM":
|
||||
E = config.ffn_config.moe_num_experts
|
||||
topk = config.ffn_config.moe_top_k
|
||||
@ -286,6 +462,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] == "DeepseekV3ForCausalLM":
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Default: Mixtral.
|
||||
E = config.num_local_experts
|
||||
@ -294,7 +475,7 @@ def main(args: argparse.Namespace):
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
dtype = config.torch_dtype
|
||||
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
|
||||
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||
|
||||
@ -322,7 +503,8 @@ def main(args: argparse.Namespace):
|
||||
return ray.get(outputs)
|
||||
|
||||
if args.tune:
|
||||
search_space = get_configs_compute_bound()
|
||||
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||
search_space = get_configs_compute_bound(is_fp16)
|
||||
print(f"Start tuning over {len(search_space)} configurations...")
|
||||
|
||||
start = time.time()
|
||||
@ -362,6 +544,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
|
@ -98,7 +98,9 @@ def main(
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
k_scale = v_scale = 1.0
|
||||
k_scale = v_scale = torch.tensor(1.0,
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
|
210
benchmarks/kernels/utils.py
Normal file
210
benchmarks/kernels/utils.py
Normal file
@ -0,0 +1,210 @@
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Iterable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as TBenchmark
|
||||
from torch.utils.benchmark import Measurement as TMeasurement
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CudaGraphBenchParams:
|
||||
num_ops_in_cuda_graph: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ArgPool:
|
||||
"""
|
||||
When some argument of the benchmarking function is annotated with this type,
|
||||
the benchmarking class (BenchMM) will collapse the argument to a pick a
|
||||
single value from the given list of values, during function invocation.
|
||||
For every invocation during a benchmarking run, it will choose a
|
||||
different value from the list.
|
||||
"""
|
||||
values: Iterable[Any]
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.values[index]
|
||||
|
||||
|
||||
class Bench:
|
||||
|
||||
class ArgsIterator:
|
||||
|
||||
def __init__(self, args_list, kwargs_list):
|
||||
assert len(args_list) == len(kwargs_list)
|
||||
self.args_list = args_list
|
||||
self.kwargs_list = kwargs_list
|
||||
self.n = len(self.args_list)
|
||||
self.idx = 0
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
yield (self.args_list[self.idx], self.kwargs_list[self.idx])
|
||||
self.idx += 1
|
||||
self.idx = self.idx % self.n
|
||||
|
||||
def reset(self):
|
||||
self.idx = 0
|
||||
|
||||
@property
|
||||
def n_args(self):
|
||||
return self.n
|
||||
|
||||
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams],
|
||||
label: str, sub_label: str, description: str, fn: Callable,
|
||||
*args, **kwargs):
|
||||
|
||||
self.cuda_graph_params = cuda_graph_params
|
||||
self.use_cuda_graph = self.cuda_graph_params is not None
|
||||
self.label = label
|
||||
self.sub_label = sub_label
|
||||
self.description = description
|
||||
self.fn = fn
|
||||
|
||||
# Process args
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self.args_list, self.kwargs_list = self.collapse_argpool(
|
||||
*args, **kwargs)
|
||||
self.args_iterator = self.ArgsIterator(self.args_list,
|
||||
self.kwargs_list)
|
||||
|
||||
# Cudagraph runner
|
||||
self.g = None
|
||||
if self.use_cuda_graph:
|
||||
self.g = self.get_cuda_graph_runner()
|
||||
|
||||
# benchmark run params
|
||||
self.min_run_time = 1
|
||||
|
||||
def collapse_argpool(self, *args, **kwargs):
|
||||
argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [
|
||||
arg for arg in kwargs.values() if isinstance(arg, ArgPool)
|
||||
]
|
||||
if len(argpool_args) == 0:
|
||||
return [args], [kwargs]
|
||||
|
||||
# Make sure all argpools are of the same size
|
||||
argpool_size = len(argpool_args[0].values)
|
||||
assert all([argpool_size == len(arg.values) for arg in argpool_args])
|
||||
|
||||
# create copies of the args
|
||||
args_list = []
|
||||
kwargs_list = []
|
||||
for _ in range(argpool_size):
|
||||
args_list.append(args)
|
||||
kwargs_list.append(kwargs.copy())
|
||||
|
||||
for i in range(argpool_size):
|
||||
# collapse args; Just pick the ith value
|
||||
args_list[i] = tuple([
|
||||
arg[i] if isinstance(arg, ArgPool) else arg
|
||||
for arg in args_list[i]
|
||||
])
|
||||
|
||||
# collapse kwargs
|
||||
kwargs_i = kwargs_list[i]
|
||||
arg_pool_keys = [
|
||||
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
|
||||
]
|
||||
for k in arg_pool_keys:
|
||||
# again just pick the ith value
|
||||
kwargs_i[k] = kwargs_i[k][i]
|
||||
kwargs_list[i] = kwargs_i
|
||||
|
||||
return args_list, kwargs_list
|
||||
|
||||
def get_cuda_graph_runner(self):
|
||||
assert self.use_cuda_graph
|
||||
assert self.args_iterator is not None
|
||||
|
||||
num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph
|
||||
|
||||
# warmup
|
||||
args_it = self.args_iterator.__next__()
|
||||
for _ in range(2):
|
||||
args, kwargs = next(args_it)
|
||||
self.fn(*args, **kwargs)
|
||||
|
||||
self.args_iterator.reset()
|
||||
args_it = self.args_iterator.__next__()
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
for _ in range(num_graph_ops):
|
||||
args, kwargs = next(args_it)
|
||||
self.fn(*args, **kwargs)
|
||||
return g
|
||||
|
||||
def run_cudagrah(self) -> TMeasurement:
|
||||
assert self.use_cuda_graph
|
||||
globals = {'g': self.g}
|
||||
|
||||
return TBenchmark.Timer(
|
||||
stmt="g.replay()",
|
||||
globals=globals,
|
||||
label=(
|
||||
f"{self.label}"
|
||||
f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
|
||||
),
|
||||
sub_label=self.sub_label,
|
||||
description=self.description,
|
||||
).blocked_autorange(min_run_time=self.min_run_time)
|
||||
|
||||
def run_eager(self) -> TMeasurement:
|
||||
setup = None
|
||||
stmt = None
|
||||
globals = None
|
||||
|
||||
has_arg_pool = self.args_iterator.n_args > 1
|
||||
if has_arg_pool:
|
||||
setup = '''
|
||||
args_iterator.reset()
|
||||
args_it = args_iterator.__next__()
|
||||
'''
|
||||
stmt = '''
|
||||
args, kwargs = next(args_it)
|
||||
fn(*args, **kwargs)
|
||||
'''
|
||||
globals = {'fn': self.fn, 'args_iterator': self.args_iterator}
|
||||
else:
|
||||
# no arg pool. Just use the args and kwargs directly
|
||||
self.args_iterator.reset()
|
||||
args_it = self.args_iterator.__next__()
|
||||
args, kwargs = next(args_it)
|
||||
|
||||
setup = ""
|
||||
stmt = '''
|
||||
fn(*args, **kwargs)
|
||||
'''
|
||||
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs}
|
||||
|
||||
return TBenchmark.Timer(
|
||||
stmt=stmt,
|
||||
setup=setup,
|
||||
globals=globals,
|
||||
label=self.label,
|
||||
sub_label=self.sub_label,
|
||||
description=self.description,
|
||||
).blocked_autorange(min_run_time=self.min_run_time)
|
||||
|
||||
def run(self) -> TMeasurement:
|
||||
timer = None
|
||||
if self.use_cuda_graph: # noqa SIM108
|
||||
timer = self.run_cudagrah()
|
||||
else:
|
||||
timer = self.run_eager()
|
||||
if not timer.meets_confidence() or timer.has_warnings:
|
||||
print("Doesn't meet confidence - re-running bench ...")
|
||||
return self.run()
|
||||
return timer
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
if exc_type:
|
||||
print(f"exc type {exc_type}")
|
||||
print(f"exc value {exc_value}")
|
||||
print(f"exc traceback {traceback}")
|
@ -4,6 +4,11 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS ON)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
|
||||
set(MACOSX_FOUND TRUE)
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
@ -13,6 +18,9 @@ endif()
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
|
||||
set (ENABLE_NUMA TRUE)
|
||||
|
||||
#
|
||||
# Check the compile flags
|
||||
#
|
||||
@ -22,18 +30,28 @@ if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
|
||||
"-mf16c"
|
||||
)
|
||||
endif()
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
|
||||
execute_process(COMMAND cat /proc/cpuinfo
|
||||
RESULT_VARIABLE CPUINFO_RET
|
||||
OUTPUT_VARIABLE CPUINFO)
|
||||
|
||||
if (NOT CPUINFO_RET EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
|
||||
if(MACOSX_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-Xpreprocessor"
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
else()
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
"-fopenmp"
|
||||
"-DVLLM_CPU_EXTENSION")
|
||||
endif()
|
||||
|
||||
if (NOT MACOSX_FOUND)
|
||||
execute_process(COMMAND cat /proc/cpuinfo
|
||||
RESULT_VARIABLE CPUINFO_RET
|
||||
OUTPUT_VARIABLE CPUINFO)
|
||||
if (NOT CPUINFO_RET EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
function (find_isa CPUINFO TARGET OUT)
|
||||
string(FIND ${CPUINFO} ${TARGET} ISA_FOUND)
|
||||
if(NOT ISA_FOUND EQUAL -1)
|
||||
@ -54,12 +72,17 @@ endfunction()
|
||||
|
||||
is_avx512_disabled(AVX512_DISABLED)
|
||||
|
||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
|
||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
|
||||
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
|
||||
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
|
||||
set(APPLE_SILICON_FOUND TRUE)
|
||||
else()
|
||||
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
|
||||
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
|
||||
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
|
||||
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
|
||||
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
|
||||
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
|
||||
endif()
|
||||
|
||||
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
list(APPEND CXX_COMPILE_FLAGS
|
||||
@ -103,6 +126,9 @@ elseif (ASIMD_FOUND)
|
||||
set(MARCH_FLAGS "-march=armv8.2-a+dotprod+fp16")
|
||||
endif()
|
||||
list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
|
||||
elseif(APPLE_SILICON_FOUND)
|
||||
message(STATUS "Apple Silicon Detected")
|
||||
set(ENABLE_NUMA OFF)
|
||||
else()
|
||||
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA or ARMv8 support.")
|
||||
endif()
|
||||
@ -139,7 +165,12 @@ endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
|
||||
list(APPEND LIBS numa)
|
||||
if(ENABLE_NUMA)
|
||||
list(APPEND LIBS numa)
|
||||
else()
|
||||
message(STATUS "NUMA is disabled")
|
||||
add_compile_definitions(-DVLLM_NUMA_DISABLED)
|
||||
endif()
|
||||
|
||||
#
|
||||
# _C extension
|
||||
|
@ -58,8 +58,8 @@ function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
|
||||
#
|
||||
set(SRCS ${ORIG_SRCS})
|
||||
set(CXX_SRCS ${ORIG_SRCS})
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
|
||||
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
|
||||
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)|(hip)$")
|
||||
|
||||
#
|
||||
# Generate ROCm/HIP source file names from CUDA file names.
|
||||
@ -259,7 +259,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
|
||||
# 9.0a to the result.
|
||||
# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@ -270,34 +270,47 @@ endmacro()
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||
|
||||
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
|
||||
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
|
||||
# less or eqault to ARCH
|
||||
foreach(_ARCH ${CUDA_ARCHS})
|
||||
set(_TMP_ARCH)
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
set(_TMP_ARCH ${_SRC_ARCH})
|
||||
else()
|
||||
break()
|
||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||
# compatibility is only forward compatible within the same major version).
|
||||
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||
set(_TMP_ARCH)
|
||||
# Extract the major version of the target arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
# Extract the major version of the source arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||
# Check major-version match AND version-less-or-equal
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||
endif()
|
||||
else()
|
||||
# If we hit a version greater than the target, we can break
|
||||
break()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# If we found a matching _TMP_ARCH, append it to _CUDA_ARCHS
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS "${_TMP_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
if (_TMP_ARCH)
|
||||
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
|
@ -9,8 +9,16 @@
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
bool act_first>
|
||||
__device__ __forceinline__ scalar_t compute(const scalar_t& x,
|
||||
const scalar_t& y) {
|
||||
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
|
||||
}
|
||||
// Activation and gating kernel template.
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||
|
||||
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
|
||||
bool act_first>
|
||||
__global__ void act_and_mul_kernel(
|
||||
scalar_t* __restrict__ out, // [..., d]
|
||||
const scalar_t* __restrict__ input, // [..., 2, d]
|
||||
@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x) * y;
|
||||
out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
|
||||
}
|
||||
}
|
||||
|
||||
@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
} // namespace vllm
|
||||
|
||||
// Launch activation and gating kernel.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
|
||||
// Use ACT_FIRST (bool) indicating whether to apply the activation function
|
||||
// first.
|
||||
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
|
||||
int d = input.size(-1) / 2; \
|
||||
int64_t num_tokens = input.numel() / input.size(-1); \
|
||||
dim3 grid(num_tokens); \
|
||||
@ -64,7 +74,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "act_and_mul_kernel", [&] { \
|
||||
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
|
||||
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
|
||||
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
|
||||
input.data_ptr<scalar_t>(), d); \
|
||||
});
|
||||
@ -72,19 +82,27 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
|
||||
void silu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
|
||||
}
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
|
||||
// applies the silu to the latter half of the input.
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
|
||||
}
|
||||
|
||||
void gelu_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
|
||||
}
|
||||
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input) // [..., 2 * d]
|
||||
{
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
|
||||
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
||||
k_vec_quant, k_scale);
|
||||
k_vec_quant, *k_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
|
||||
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
||||
v_scale);
|
||||
*v_scale);
|
||||
}
|
||||
if (block_idx == num_seq_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
||||
@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const float* k_scale, const float* v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
|
@ -41,7 +41,7 @@
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||
k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
|
||||
blocksparse_vert_stride, blocksparse_block_size, \
|
||||
blocksparse_head_sliding_step);
|
||||
|
||||
@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -80,6 +80,8 @@ void paged_attention_v1_launcher(
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int padded_max_seq_len =
|
||||
@ -176,9 +178,10 @@ void paged_attention_v1(
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
@ -37,7 +37,7 @@
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
||||
kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
|
||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||
blocksparse_block_size, blocksparse_head_sliding_step); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||
@ -54,10 +54,10 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -84,6 +84,8 @@ void paged_attention_v2_launcher(
|
||||
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
|
||||
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
||||
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
@ -187,9 +189,10 @@ void paged_attention_v2(
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
@ -18,15 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale);
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
const double k_scale, const double v_scale);
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
|
@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
// block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x, const float k_scale,
|
||||
const float v_scale) {
|
||||
const int head_size, const int block_size, const int x,
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float k_scale, const float v_scale) {
|
||||
const float* k_scale, const float* v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
value_cache[tgt_key_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
|
||||
value_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||
num_heads, head_size, block_size, x, k_scale, v_scale);
|
||||
num_heads, head_size, block_size, x, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -268,8 +270,8 @@ void reshape_and_cache(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
@ -299,7 +301,9 @@ void reshape_and_cache(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
|
||||
value_stride, num_heads, head_size, block_size, \
|
||||
reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -308,8 +312,8 @@ void reshape_and_cache_flash(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
|
@ -32,7 +32,7 @@ class ScalarType {
|
||||
signed_(signed_),
|
||||
bias(bias),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr){};
|
||||
nan_repr(nan_repr) {};
|
||||
|
||||
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
|
||||
return ScalarType(0, size_bits - 1, true, bias);
|
||||
|
@ -386,7 +386,7 @@ void paged_attention_v1_impl_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
const std::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -459,12 +459,12 @@ void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
@ -702,7 +702,7 @@ void paged_attention_v2_impl_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
|
||||
int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
|
||||
int max_seq_len, const std::optional<torch::Tensor>& alibi_slopes) {
|
||||
int num_seqs = query.size(0);
|
||||
int num_heads = query.size(1);
|
||||
int head_size = query.size(2);
|
||||
@ -781,12 +781,12 @@ void paged_attention_v2(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
|
@ -107,10 +107,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, torch::Tensor& v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
|
@ -2,13 +2,13 @@
|
||||
#define CPU_TYPES_HPP
|
||||
|
||||
#if defined(__x86_64__)
|
||||
//x86 implementation
|
||||
// x86 implementation
|
||||
#include "cpu_types_x86.hpp"
|
||||
#elif defined(__POWER9_VECTOR__)
|
||||
//ppc implementation
|
||||
// ppc implementation
|
||||
#include "cpu_types_vsx.hpp"
|
||||
#elif defined(__aarch64__)
|
||||
//arm implementation
|
||||
// arm implementation
|
||||
#include "cpu_types_arm.hpp"
|
||||
#else
|
||||
#warning "unsupported vLLM cpu implementation"
|
||||
|
@ -1,48 +1,50 @@
|
||||
#include <arm_neon.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/all.h>
|
||||
#include <cmath>
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
#else
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
#endif
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
};
|
||||
};
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
};
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
|
||||
};
|
||||
|
||||
@ -54,53 +56,106 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
|
||||
float16x8_t reg;
|
||||
|
||||
explicit FP16Vec8(const void *ptr)
|
||||
: reg(vld1q_f16(static_cast<const __fp16 *>(ptr))) {};
|
||||
explicit FP16Vec8(const void* ptr)
|
||||
: reg(vld1q_f16(static_cast<const __fp16*>(ptr))) {};
|
||||
|
||||
explicit FP16Vec8(const FP32Vec8 &);
|
||||
explicit FP16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const {
|
||||
vst1q_f16(static_cast<__fp16 *>(ptr), reg);
|
||||
}
|
||||
void save(void* ptr) const { vst1q_f16(static_cast<__fp16*>(ptr), reg); }
|
||||
};
|
||||
|
||||
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
float16x8x2_t reg;
|
||||
|
||||
explicit FP16Vec16(const void *ptr) {
|
||||
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
|
||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||
}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16& vec);
|
||||
|
||||
void save(void *ptr) const {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
|
||||
void save(void *ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
if (full_blocks > 1) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
if (remainder > 0) {
|
||||
float16x8_t temp = reg.val[full_blocks];
|
||||
for (int i = 0; i < remainder; ++i) {
|
||||
reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] = vgetq_lane_f16(temp, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
|
||||
float16x8x2_t reg;
|
||||
|
||||
explicit FP16Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr));
|
||||
reg.val[1] = vld1q_f16(reinterpret_cast<const __fp16*>(ptr) + 8);
|
||||
}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16& vec);
|
||||
|
||||
void save(void* ptr) const {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
if (full_blocks > 1) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr) + 8, reg.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
// Note: below is the unrolled version of the following code:
|
||||
//
|
||||
// for (int i = 0; i < remainder; ++i) {
|
||||
// reinterpret_cast<__fp16*>(ptr)[full_blocks * 8 + i] =
|
||||
// vgetq_lane_f16(temp, i);
|
||||
// }
|
||||
//
|
||||
// For macOS build (Clang), the arm/neon intrinsics function
|
||||
// `vgetq_lane_f16` needs the parameter `i` to be constant at compile
|
||||
// time.
|
||||
|
||||
if (remainder > 0) {
|
||||
float16x8_t temp = reg.val[full_blocks];
|
||||
__fp16* fp16_ptr = reinterpret_cast<__fp16*>(ptr);
|
||||
switch (remainder) {
|
||||
case 1:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
break;
|
||||
case 2:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
break;
|
||||
case 3:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
break;
|
||||
case 4:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
break;
|
||||
case 5:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
break;
|
||||
case 6:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
break;
|
||||
case 7:
|
||||
fp16_ptr[full_blocks * 8 + 0] = vgetq_lane_f16(temp, 0);
|
||||
fp16_ptr[full_blocks * 8 + 1] = vgetq_lane_f16(temp, 1);
|
||||
fp16_ptr[full_blocks * 8 + 2] = vgetq_lane_f16(temp, 2);
|
||||
fp16_ptr[full_blocks * 8 + 3] = vgetq_lane_f16(temp, 3);
|
||||
fp16_ptr[full_blocks * 8 + 4] = vgetq_lane_f16(temp, 4);
|
||||
fp16_ptr[full_blocks * 8 + 5] = vgetq_lane_f16(temp, 5);
|
||||
fp16_ptr[full_blocks * 8 + 6] = vgetq_lane_f16(temp, 6);
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
@ -108,16 +163,17 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
bfloat16x8_t reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8_t *>(ptr)) {};
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec8(bfloat16x8_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
explicit BF16Vec8(float32x4x2_t v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
|
||||
explicit BF16Vec8(float32x4x2_t v)
|
||||
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1])) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8_t *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -125,19 +181,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
bfloat16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x2_t *>(ptr)) {};
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x2_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec16(bfloat16x8x2_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
explicit BF16Vec16(float32x4x4_t v) : reg({
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])
|
||||
}){};
|
||||
explicit BF16Vec16(float32x4x4_t v)
|
||||
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[0]), v.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8x2_t *>(ptr) = reg; };
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -145,19 +200,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
bfloat16x8x4_t reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x4_t *>(ptr)) {};
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const bfloat16x8x4_t*>(ptr)) {};
|
||||
|
||||
explicit BF16Vec32(bfloat16x8x4_t data) : reg(data) {};
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg
|
||||
}) {};
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<bfloat16x8x4_t *>(ptr) = reg; };
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -175,11 +226,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(vdupq_n_f32(0.0f)) {};
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(vld1q_f32(ptr)) {};
|
||||
explicit FP32Vec4(const float* ptr) : reg(vld1q_f32(ptr)) {};
|
||||
|
||||
explicit FP32Vec4(float32x4_t data) : reg(data) {};
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {};
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -195,32 +246,37 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
|
||||
explicit FP32Vec8() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {};
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
|
||||
explicit FP32Vec8(const float* ptr)
|
||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4)}) {};
|
||||
|
||||
explicit FP32Vec8(float32x4x2_t data) : reg(data) {};
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {};
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8 &v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
|
||||
};
|
||||
explicit FP32Vec8(const FP16Vec8& v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg));
|
||||
};
|
||||
|
||||
explicit FP32Vec8(float16x8_t v) : reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
|
||||
explicit FP32Vec8(float16x8_t v)
|
||||
: reg({vcvt_f32_f16(vget_low_f16(v)), vcvt_f32_f16(vget_high_f16(v))}) {};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
|
||||
explicit FP32Vec8(bfloat16x8_t v) : reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
|
||||
explicit FP32Vec8(bfloat16x8_t v)
|
||||
: reg({vcvtq_low_f32_bf16(v), vcvtq_high_f32_bf16(v)}) {};
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v) : reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
|
||||
explicit FP32Vec8(const BF16Vec8& v)
|
||||
: reg({vcvtq_low_f32_bf16(v.reg), vcvtq_high_f32_bf16(v.reg)}) {};
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float answer = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
|
||||
return answer;
|
||||
}
|
||||
@ -267,10 +323,14 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
|
||||
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])), static_cast<float32_t>(erf(ar.values[1]))};
|
||||
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])), static_cast<float32_t>(erf(ar.values[3]))};
|
||||
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])), static_cast<float32_t>(erf(ar.values[5]))};
|
||||
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])), static_cast<float32_t>(erf(ar.values[7]))};
|
||||
float32x2_t er_vec0 = {static_cast<float32_t>(erf(ar.values[0])),
|
||||
static_cast<float32_t>(erf(ar.values[1]))};
|
||||
float32x2_t er_vec1 = {static_cast<float32_t>(erf(ar.values[2])),
|
||||
static_cast<float32_t>(erf(ar.values[3]))};
|
||||
float32x2_t er_vec2 = {static_cast<float32_t>(erf(ar.values[4])),
|
||||
static_cast<float32_t>(erf(ar.values[5]))};
|
||||
float32x2_t er_vec3 = {static_cast<float32_t>(erf(ar.values[6])),
|
||||
static_cast<float32_t>(erf(ar.values[7]))};
|
||||
|
||||
float32x4_t result0 = vcombine_f32(er_vec0, er_vec1);
|
||||
float32x4_t result1 = vcombine_f32(er_vec2, er_vec3);
|
||||
@ -280,25 +340,29 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
result.val[1] = result1;
|
||||
|
||||
return FP32Vec8(result);
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]), vmulq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]), vsubq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]), vdivq_f32(reg.val[1], b.reg.val[1])}));
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(float32x4x2_t({vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1])}));
|
||||
}
|
||||
|
||||
void save(float* ptr) const {
|
||||
vst1q_f32(ptr, reg.val[0]);
|
||||
vst1q_f32(ptr + 4, reg.val[1]);
|
||||
}
|
||||
@ -313,103 +377,100 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float32x4x4_t reg;
|
||||
|
||||
explicit FP32Vec16(float v) : reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
|
||||
explicit FP32Vec16(float v)
|
||||
: reg({vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v), vmovq_n_f32(v)}) {}
|
||||
|
||||
explicit FP32Vec16() : reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0)}) {}
|
||||
explicit FP32Vec16()
|
||||
: reg({vmovq_n_f32(0.0), vmovq_n_f32(0.0), vmovq_n_f32(0.0),
|
||||
vmovq_n_f32(0.0)}) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8), vld1q_f32(ptr + 12)}) {}
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
: reg({vld1q_f32(ptr), vld1q_f32(ptr + 4), vld1q_f32(ptr + 8),
|
||||
vld1q_f32(ptr + 12)}) {}
|
||||
|
||||
explicit FP32Vec16(float32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v.reg)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v.reg)) {}
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(bfloat16x8x2_t v) : reg({
|
||||
vcvtq_low_f32_bf16(v.val[0]),
|
||||
vcvtq_high_f32_bf16(v.val[0]),
|
||||
vcvtq_low_f32_bf16(v.val[1]),
|
||||
vcvtq_high_f32_bf16(v.val[1])
|
||||
}) {};
|
||||
#endif
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(bfloat16x8x2_t v)
|
||||
: reg({vcvtq_low_f32_bf16(v.val[0]), vcvtq_high_f32_bf16(v.val[0]),
|
||||
vcvtq_low_f32_bf16(v.val[1]), vcvtq_high_f32_bf16(v.val[1])}) {};
|
||||
#endif
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(const BF16Vec16 &v) : reg({
|
||||
vcvtq_low_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_low_f32_bf16(v.reg.val[1]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[1])
|
||||
}) {};
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
explicit FP32Vec16(const BF16Vec16& v)
|
||||
: reg({vcvtq_low_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[0]),
|
||||
vcvtq_low_f32_bf16(v.reg.val[1]),
|
||||
vcvtq_high_f32_bf16(v.reg.val[1])}) {};
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
#endif
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {};
|
||||
#endif
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
reg.val[0] = vcvt_f32_f16(vget_low_f16(v.reg.val[0]));
|
||||
reg.val[1] = vcvt_f32_f16(vget_high_f16(v.reg.val[0]));
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
vaddq_f32(reg.val[2], b.reg.val[2]),
|
||||
vaddq_f32(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
vaddq_f32(reg.val[2], b.reg.val[2]),
|
||||
vaddq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1]),
|
||||
vmulq_f32(reg.val[2], b.reg.val[2]),
|
||||
vmulq_f32(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmulq_f32(reg.val[0], b.reg.val[0]),
|
||||
vmulq_f32(reg.val[1], b.reg.val[1]),
|
||||
vmulq_f32(reg.val[2], b.reg.val[2]),
|
||||
vmulq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1]),
|
||||
vsubq_f32(reg.val[2], b.reg.val[2]),
|
||||
vsubq_f32(reg.val[3], b.reg.val[3])
|
||||
}));
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vsubq_f32(reg.val[0], b.reg.val[0]),
|
||||
vsubq_f32(reg.val[1], b.reg.val[1]),
|
||||
vsubq_f32(reg.val[2], b.reg.val[2]),
|
||||
vsubq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1]),
|
||||
vdivq_f32(reg.val[2], b.reg.val[2]),
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])
|
||||
}));
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vdivq_f32(reg.val[0], b.reg.val[0]),
|
||||
vdivq_f32(reg.val[1], b.reg.val[1]),
|
||||
vdivq_f32(reg.val[2], b.reg.val[2]),
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float answer = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&answer, &ar](int i) { answer += ar.values[i]; });
|
||||
|
||||
return answer;
|
||||
};
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
@ -422,7 +483,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vst1q_f32(ptr, reg.val[0]);
|
||||
vst1q_f32(ptr + 4, reg.val[1]);
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
@ -430,43 +491,59 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<__fp16 *>(ptr) = v;
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v) {
|
||||
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
|
||||
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
|
||||
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
|
||||
template <>
|
||||
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
|
||||
*reinterpret_cast<__fp16*>(ptr) = v;
|
||||
}
|
||||
|
||||
reg.val[0] = vcombine_f16(low_0, high_0);
|
||||
reg.val[1] = vcombine_f16(low_1, high_1);
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
|
||||
float16x4_t low_0 = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t high_0 = vcvt_f16_f32(v.reg.val[1]);
|
||||
float16x4_t low_1 = vcvt_f16_f32(v.reg.val[2]);
|
||||
float16x4_t high_1 = vcvt_f16_f32(v.reg.val[3]);
|
||||
|
||||
reg.val[0] = vcombine_f16(low_0, high_0);
|
||||
reg.val[1] = vcombine_f16(low_1, high_1);
|
||||
};
|
||||
|
||||
inline FP16Vec8 :: FP16Vec8(const FP32Vec8 &v) {
|
||||
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
|
||||
inline FP16Vec8 ::FP16Vec8(const FP32Vec8& v) {
|
||||
float16x4_t lower_half = vcvt_f16_f32(v.reg.val[0]);
|
||||
float16x4_t upper_half = vcvt_f16_f32(v.reg.val[1]);
|
||||
|
||||
reg = vcombine_f16(lower_half, upper_half);
|
||||
reg = vcombine_f16(lower_half, upper_half);
|
||||
};
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc.reg.val[0] = vfmaq_f32(acc.reg.val[0], a.reg.val[0], b.reg.val[0]);
|
||||
acc.reg.val[1] = vfmaq_f32(acc.reg.val[1], a.reg.val[1], b.reg.val[1]);
|
||||
acc.reg.val[2] = vfmaq_f32(acc.reg.val[2], a.reg.val[2], b.reg.val[2]);
|
||||
@ -474,8 +551,7 @@ inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
};
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
|
||||
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
|
||||
float32x4_t a0_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[0]));
|
||||
float32x4_t a0_high = vcvt_f32_bf16(vget_high_bf16(a.reg.val[0]));
|
||||
float32x4_t a1_low = vcvt_f32_bf16(vget_low_bf16(a.reg.val[1]));
|
||||
@ -494,22 +570,22 @@ inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
#endif
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) : reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {};
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1])) {
|
||||
};
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) : reg({
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]), v.reg.val[3])
|
||||
}){};
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg({vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[0]), v.reg.val[1]),
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.reg.val[2]),
|
||||
v.reg.val[3])}) {};
|
||||
#endif
|
||||
|
||||
inline void prefetch(const void *addr) {
|
||||
__builtin_prefetch(addr, 0, 1);
|
||||
};
|
||||
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); };
|
||||
|
||||
#ifdef ARM_BF16_SUPPORT
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bf16 *>(ptr) = vcvth_bf16_f32(v);
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*reinterpret_cast<__bf16*>(ptr) = vcvth_bf16_f32(v);
|
||||
};
|
||||
#endif
|
||||
};
|
||||
}; // namespace vec_op
|
@ -9,38 +9,40 @@
|
||||
namespace vec_op {
|
||||
|
||||
// FIXME: FP16 is not fully supported in Torch-CPU
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
std::cout << #NAME << " invoked." << std::endl;
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME) \
|
||||
std::cout << #NAME << " exit." << std::endl;
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
@ -68,12 +70,14 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
__vector signed short reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {}
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg((__vector signed short)vec_xl(0, (__vector signed short*)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; }
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__vector signed short*>(ptr) = reg;
|
||||
}
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -81,18 +85,18 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
ss16x8x2_t reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr) {
|
||||
explicit BF16Vec16(const void* ptr) {
|
||||
// Load 256 bits in two parts
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr);
|
||||
reg.val[0] = (__vector signed short)vec_xl(0, (signed short*)ptr);
|
||||
reg.val[1] = (__vector signed short)vec_xl(16, (signed short*)ptr);
|
||||
}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const {
|
||||
void save(void* ptr) const {
|
||||
// Save 256 bits in two parts
|
||||
vec_xst(reg.val[0], 0, (signed short *)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short *)ptr);
|
||||
vec_xst(reg.val[0], 0, (signed short*)ptr);
|
||||
vec_xst(reg.val[1], 16, (signed short*)ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@ -102,19 +106,15 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
constexpr static int VEC_ELEM_NUM = 32;
|
||||
|
||||
ss16x8x4_t reg;
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t *>(ptr)) {}
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg(*reinterpret_cast<const ss16x8x4_t*>(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg,
|
||||
vec8_data.reg
|
||||
}) {}
|
||||
explicit BF16Vec32(const BF16Vec8& vec8_data)
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<ss16x8x4_t *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<ss16x8x4_t*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
@ -130,11 +130,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {}
|
||||
explicit FP32Vec4(const float* ptr) : reg(vec_xl(0, ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__vector float data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -156,19 +156,19 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
reg.val[1] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) {
|
||||
explicit FP32Vec8(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
}
|
||||
|
||||
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) {
|
||||
explicit FP32Vec8(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v) {
|
||||
explicit FP32Vec8(const BF16Vec8& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
|
||||
}
|
||||
@ -177,7 +177,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -230,23 +231,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(
|
||||
{vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
}
|
||||
@ -275,7 +280,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[3] = vec_splats(0.0f);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) {
|
||||
explicit FP32Vec16(const float* ptr) {
|
||||
reg.val[0] = vec_xl(0, ptr);
|
||||
reg.val[1] = vec_xl(16, ptr);
|
||||
reg.val[2] = vec_xl(32, ptr);
|
||||
@ -284,78 +289,76 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) {
|
||||
explicit FP32Vec16(const FP32Vec16& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[2];
|
||||
reg.val[3] = data.reg.val[3];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data) {
|
||||
explicit FP32Vec16(const FP32Vec4& data) {
|
||||
reg.val[0] = data.reg;
|
||||
reg.val[1] = data.reg;
|
||||
reg.val[2] = data.reg;
|
||||
reg.val[3] = data.reg;
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data) {
|
||||
explicit FP32Vec16(const FP32Vec8& data) {
|
||||
reg.val[0] = data.reg.val[0];
|
||||
reg.val[1] = data.reg.val[1];
|
||||
reg.val[2] = data.reg.val[0];
|
||||
reg.val[3] = data.reg.val[1];
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
|
||||
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
|
||||
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
|
||||
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
|
||||
vec_mul(reg.val[1], b.reg.val[1]),
|
||||
vec_mul(reg.val[2], b.reg.val[2]),
|
||||
vec_mul(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_add(reg.val[0], b.reg.val[0]),
|
||||
vec_add(reg.val[1], b.reg.val[1]),
|
||||
vec_add(reg.val[2], b.reg.val[2]),
|
||||
vec_add(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_sub(reg.val[0], b.reg.val[0]),
|
||||
vec_sub(reg.val[1], b.reg.val[1]),
|
||||
vec_sub(reg.val[2], b.reg.val[2]),
|
||||
vec_sub(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
return FP32Vec16(f32x4x4_t({
|
||||
vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(f32x4x4_t({vec_div(reg.val[0], b.reg.val[0]),
|
||||
vec_div(reg.val[1], b.reg.val[1]),
|
||||
vec_div(reg.val[2], b.reg.val[2]),
|
||||
vec_div(reg.val[3], b.reg.val[3])}));
|
||||
}
|
||||
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
|
||||
AliasReg ar;
|
||||
@ -368,7 +371,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return result;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
vec_xst(reg.val[0], 0, ptr);
|
||||
vec_xst(reg.val[1], 16, ptr);
|
||||
vec_xst(reg.val[2], 32, ptr);
|
||||
@ -376,43 +379,62 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#define __VEC_CLASS_FP_NAN (1 << 6)
|
||||
#endif
|
||||
|
||||
const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 };
|
||||
const static __vector unsigned char omask = {0, 1, 4, 5, 8, 9, 12, 13,
|
||||
16, 17, 20, 21, 24, 25, 28, 29};
|
||||
#ifndef _ARCH_PWR10
|
||||
const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff };
|
||||
const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 };
|
||||
const static __vector unsigned int sh16 = { 16, 16, 16, 16 };
|
||||
const static __vector unsigned int one = { 1, 1, 1, 1 };
|
||||
const static __vector unsigned int bias = {0x00007fff, 0x00007fff, 0x00007fff,
|
||||
0x00007fff};
|
||||
const static __vector unsigned int nan = {0x7fc00000, 0x7fc00000, 0x7fc00000,
|
||||
0x7fc00000};
|
||||
const static __vector unsigned int sh16 = {16, 16, 16, 16};
|
||||
const static __vector unsigned int one = {1, 1, 1, 1};
|
||||
#endif
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) {
|
||||
#ifdef _ARCH_PWR10
|
||||
__vector signed short ret[2];
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[1]);
|
||||
reg = vec_perm(ret[0], ret[1], omask);
|
||||
#elif defined(_ARCH_PWR9)
|
||||
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
|
||||
@ -425,8 +447,10 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
__vector unsigned int rnd1 = vec_add(lsb1, bias);
|
||||
inp0 = vec_add(inp0, rnd0);
|
||||
inp1 = vec_add(inp1, rnd1);
|
||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel0 =
|
||||
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 =
|
||||
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
inp0 = vec_sel(inp0, nan, sel0);
|
||||
inp1 = vec_sel(inp1, nan, sel1);
|
||||
inp0 = vec_sr(inp0, sh16);
|
||||
@ -435,13 +459,17 @@ inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
#ifdef _ARCH_PWR10
|
||||
__vector signed short ret[4];
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
|
||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]);
|
||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]);
|
||||
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[0]);
|
||||
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[1]);
|
||||
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[2]);
|
||||
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16(
|
||||
(__vector unsigned char)v.reg.val[3]);
|
||||
reg.val[0] = vec_perm(ret[0], ret[1], omask);
|
||||
reg.val[1] = vec_perm(ret[2], ret[3], omask);
|
||||
#elif defined(_ARCH_PWR9)
|
||||
@ -465,10 +493,14 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inp1 = vec_add(inp1, rnd1);
|
||||
inp2 = vec_add(inp2, rnd2);
|
||||
inp3 = vec_add(inp3, rnd3);
|
||||
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel0 =
|
||||
vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel1 =
|
||||
vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel2 =
|
||||
vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
|
||||
__vector __bool int sel3 =
|
||||
vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
|
||||
inp0 = vec_sel(inp0, nan, sel0);
|
||||
inp1 = vec_sel(inp1, nan, sel1);
|
||||
inp2 = vec_sel(inp2, nan, sel2);
|
||||
@ -482,10 +514,10 @@ inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
#endif
|
||||
}
|
||||
|
||||
inline void prefetch(const void *addr) {
|
||||
inline void prefetch(const void* addr) {
|
||||
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
|
||||
}
|
||||
|
||||
}; // namespace vec_op
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
@ -11,39 +11,40 @@ static_assert(false, "AVX2 must be supported for the current implementation.");
|
||||
|
||||
namespace vec_op {
|
||||
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
||||
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
||||
|
||||
#ifndef CPU_OP_GUARD
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME)
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#else
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#define CPU_KERNEL_GUARD_IN(NAME) \
|
||||
RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
|
||||
#define CPU_KERNEL_GUARD_OUT(NAME)
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
|
||||
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
|
||||
(f(std::integral_constant<T, indexes>{}), ...);
|
||||
}
|
||||
}; // namespace
|
||||
}; // namespace
|
||||
|
||||
template <typename T, T count, typename F,
|
||||
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
|
||||
constexpr void unroll_loop(F &&f) {
|
||||
constexpr void unroll_loop(F&& f) {
|
||||
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
|
||||
}
|
||||
|
||||
template <typename T> struct Vec {
|
||||
template <typename T>
|
||||
struct Vec {
|
||||
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
|
||||
};
|
||||
|
||||
@ -55,12 +56,12 @@ struct FP16Vec8 : public Vec<FP16Vec8> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit FP16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
explicit FP16Vec8(const void* ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec8(const FP32Vec8 &);
|
||||
explicit FP16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
@ -68,12 +69,12 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit FP16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
explicit FP16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit FP16Vec16(const FP32Vec16 &);
|
||||
explicit FP16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -87,12 +88,12 @@ struct BF16Vec8 : public Vec<BF16Vec8> {
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit BF16Vec8(const void *ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
|
||||
explicit BF16Vec8(const void* ptr)
|
||||
: reg((__m128i)_mm_loadu_si128((__m128i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec8(const FP32Vec8 &);
|
||||
explicit BF16Vec8(const FP32Vec8&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m128i*>(ptr) = reg; }
|
||||
};
|
||||
|
||||
struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
@ -100,12 +101,12 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
|
||||
__m256i reg;
|
||||
|
||||
explicit BF16Vec16(const void *ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
|
||||
explicit BF16Vec16(const void* ptr)
|
||||
: reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {}
|
||||
|
||||
explicit BF16Vec16(const FP32Vec16 &);
|
||||
explicit BF16Vec16(const FP32Vec16&);
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -120,11 +121,11 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
explicit BF16Vec32(const void* ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
|
||||
|
||||
explicit BF16Vec32(__m512i data) : reg(data) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data)
|
||||
: reg((__m512i)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
|
||||
(__m128i)vec8_data.reg),
|
||||
@ -132,7 +133,7 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
(__m128i)vec8_data.reg, 2),
|
||||
(__m128i)vec8_data.reg, 3)) {}
|
||||
|
||||
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
|
||||
void save(void* ptr) const { *reinterpret_cast<__m512i*>(ptr) = reg; }
|
||||
};
|
||||
#else
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -141,24 +142,24 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
__m256i reg_low;
|
||||
__m256i reg_high;
|
||||
|
||||
explicit BF16Vec32(const void *ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
|
||||
explicit BF16Vec32(const void* ptr)
|
||||
: reg_low(_mm256_loadu_si256((__m256i const*)ptr)),
|
||||
reg_high(_mm256_loadu_si256((__m256i const*)ptr + 1)) {}
|
||||
|
||||
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
|
||||
reg_high(high) {}
|
||||
explicit BF16Vec32(__m256i low, __m256i high)
|
||||
: reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit BF16Vec32(BF16Vec8 &vec8_data)
|
||||
explicit BF16Vec32(BF16Vec8& vec8_data)
|
||||
: reg_low((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)),
|
||||
reg_high((__m256i)_mm256_inserti32x4(
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
_mm256_castsi128_si256((__m128i)vec8_data.reg),
|
||||
(__m128i)vec8_data.reg, 1)) {}
|
||||
|
||||
void save(void *ptr) const {
|
||||
*reinterpret_cast<__m256i *>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
|
||||
void save(void* ptr) const {
|
||||
*reinterpret_cast<__m256i*>(ptr) = reg_low;
|
||||
*reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
@ -176,11 +177,11 @@ struct FP32Vec4 : public Vec<FP32Vec4> {
|
||||
|
||||
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec4(const float* ptr) : reg(_mm_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec4(__m128 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {}
|
||||
};
|
||||
|
||||
struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
@ -196,15 +197,15 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
|
||||
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec8(const float* ptr) : reg(_mm256_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec8(__m256 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
|
||||
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {}
|
||||
|
||||
explicit FP32Vec8(const FP16Vec8 &v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||
explicit FP32Vec8(const FP16Vec8& v) : reg(_mm256_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec8(const BF16Vec8 &v)
|
||||
explicit FP32Vec8(const BF16Vec8& v)
|
||||
: reg(_mm256_castsi256_ps(
|
||||
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
@ -212,7 +213,8 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float result = 0;
|
||||
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&result, &ar](int i) { result += ar.values[i]; });
|
||||
|
||||
return result;
|
||||
}
|
||||
@ -244,27 +246,27 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
erf(ar.values[1]), erf(ar.values[0])));
|
||||
}
|
||||
|
||||
FP32Vec8 operator*(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator*(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator+(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator+(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator-(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator-(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec8 operator/(const FP32Vec8 &b) const {
|
||||
FP32Vec8 operator/(const FP32Vec8& b) const {
|
||||
return FP32Vec8(_mm256_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
void save(float* ptr) const { _mm256_storeu_ps(ptr, reg); }
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT32Vec16: public Vec<INT32Vec16> {
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m512i reg;
|
||||
@ -272,12 +274,11 @@ struct INT32Vec16: public Vec<INT32Vec16> {
|
||||
};
|
||||
|
||||
__m512i reg;
|
||||
|
||||
explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
_mm512_storeu_epi32(ptr, reg);
|
||||
}
|
||||
explicit INT32Vec16(const void* data_ptr)
|
||||
: reg(_mm512_loadu_epi32(data_ptr)) {}
|
||||
|
||||
void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -301,11 +302,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {}
|
||||
|
||||
explicit FP32Vec16(__m512 data) : reg(data) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg((__m512)_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(
|
||||
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
|
||||
@ -313,36 +314,37 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
(__m128i)data.reg, 2),
|
||||
(__m128i)data.reg, 3)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg((__m512)_mm512_inserti32x8(
|
||||
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v)
|
||||
explicit FP32Vec16(const BF16Vec16& v)
|
||||
: reg(_mm512_castsi512_ps(
|
||||
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
explicit FP32Vec16(const FP16Vec16& v) : reg(_mm512_cvtph_ps(v.reg)) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16 &v)
|
||||
: reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {}
|
||||
explicit FP32Vec16(const INT32Vec16& v)
|
||||
: reg(_mm512_cvt_roundepi32_ps(
|
||||
v.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_add_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm512_div_ps(reg, b.reg));
|
||||
}
|
||||
|
||||
@ -370,9 +372,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
|
||||
}
|
||||
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(_mm512_abs_ps(reg));
|
||||
}
|
||||
FP32Vec16 abs() const { return FP32Vec16(_mm512_abs_ps(reg)); }
|
||||
|
||||
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
|
||||
|
||||
@ -380,14 +380,15 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
|
||||
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
|
||||
return _mm512_mask_reduce_add_ps(mask, reg);
|
||||
}
|
||||
|
||||
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
void save(float* ptr) const { _mm512_storeu_ps(ptr, reg); }
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -407,32 +408,30 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
__m256 reg_low;
|
||||
__m256 reg_high;
|
||||
|
||||
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
|
||||
reg_high(_mm256_set1_ps(v)) {}
|
||||
explicit FP32Vec16(float v)
|
||||
: reg_low(_mm256_set1_ps(v)), reg_high(_mm256_set1_ps(v)) {}
|
||||
|
||||
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
|
||||
reg_high(_mm256_set1_ps(0.0)) {}
|
||||
explicit FP32Vec16()
|
||||
: reg_low(_mm256_set1_ps(0.0)), reg_high(_mm256_set1_ps(0.0)) {}
|
||||
|
||||
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
|
||||
reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
explicit FP32Vec16(const float* ptr)
|
||||
: reg_low(_mm256_loadu_ps(ptr)), reg_high(_mm256_loadu_ps(ptr + 8)) {}
|
||||
|
||||
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
|
||||
reg_high(data.reg_high) {}
|
||||
explicit FP32Vec16(const FP32Vec16& data)
|
||||
: reg_low(data.reg_low), reg_high(data.reg_high) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec4 &data)
|
||||
explicit FP32Vec16(const FP32Vec4& data)
|
||||
: reg_low((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)),
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)),
|
||||
reg_high((__m256)_mm256_inserti128_si256(
|
||||
_mm256_castsi128_si256((__m128i)data.reg),
|
||||
(__m128i)data.reg, 1)) {}
|
||||
_mm256_castsi128_si256((__m128i)data.reg), (__m128i)data.reg, 1)) {}
|
||||
|
||||
explicit FP32Vec16(const FP32Vec8 &data)
|
||||
explicit FP32Vec16(const FP32Vec8& data)
|
||||
: reg_low(data.reg), reg_high(data.reg) {}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec16 &v) {
|
||||
explicit FP32Vec16(const FP16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@ -440,9 +439,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg_high = _mm256_cvtph_ps(high);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const FP16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const FP16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec16 &v) {
|
||||
explicit FP32Vec16(const BF16Vec16& v) {
|
||||
__m128i low = _mm256_extractf128_si256(v.reg, 0);
|
||||
__m128i high = _mm256_extractf128_si256(v.reg, 1);
|
||||
|
||||
@ -456,24 +455,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg_high = _mm256_castsi256_ps(v_high_shifted);
|
||||
}
|
||||
|
||||
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
|
||||
|
||||
FP32Vec16 operator*(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator*(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
|
||||
_mm256_mul_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator+(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
|
||||
_mm256_add_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator-(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator-(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
|
||||
_mm256_sub_ps(reg_high, b.reg_high));
|
||||
}
|
||||
|
||||
FP32Vec16 operator/(const FP32Vec16 &b) const {
|
||||
FP32Vec16 operator/(const FP32Vec16& b) const {
|
||||
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
|
||||
_mm256_div_ps(reg_high, b.reg_high));
|
||||
}
|
||||
@ -484,7 +483,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return low.reduce_sum() + high.reduce_sum();
|
||||
}
|
||||
|
||||
template <int group_size> float reduce_sub_sum(int idx) {
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
float sum = 0.0;
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
|
||||
@ -507,7 +507,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return sum;
|
||||
}
|
||||
|
||||
void save(float *ptr) const {
|
||||
void save(float* ptr) const {
|
||||
_mm256_storeu_ps(ptr, reg_low);
|
||||
_mm256_storeu_ps(ptr + 8, reg_high);
|
||||
}
|
||||
@ -515,7 +515,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512F__
|
||||
struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
__m128i reg;
|
||||
@ -523,14 +523,12 @@ struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
};
|
||||
|
||||
__m128i reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) : reg(
|
||||
_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
) {}
|
||||
|
||||
void save(int8_t* ptr) const {
|
||||
_mm_storeu_epi8(ptr, reg);
|
||||
}
|
||||
explicit INT8Vec16(const FP32Vec16& vec)
|
||||
: reg(_mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(
|
||||
vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))) {}
|
||||
|
||||
void save(int8_t* ptr) const { _mm_storeu_epi8(ptr, reg); }
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
constexpr uint32_t M = 0xFFFFFFFF;
|
||||
@ -540,71 +538,92 @@ struct INT8Vec16: public Vec<INT8Vec16> {
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T> struct VecType { using vec_type = void; };
|
||||
template <typename T>
|
||||
struct VecType {
|
||||
using vec_type = void;
|
||||
};
|
||||
|
||||
template <typename T> using vec_t = typename VecType<T>::vec_type;
|
||||
template <typename T>
|
||||
using vec_t = typename VecType<T>::vec_type;
|
||||
|
||||
template <> struct VecType<float> { using vec_type = FP32Vec8; };
|
||||
template <>
|
||||
struct VecType<float> {
|
||||
using vec_type = FP32Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::Half> { using vec_type = FP16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::Half> {
|
||||
using vec_type = FP16Vec8;
|
||||
};
|
||||
|
||||
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
|
||||
template <>
|
||||
struct VecType<c10::BFloat16> {
|
||||
using vec_type = BF16Vec8;
|
||||
};
|
||||
|
||||
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
|
||||
template <typename T>
|
||||
void storeFP32(float v, T* ptr) {
|
||||
*ptr = v;
|
||||
}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
|
||||
inline void fma(FP32Vec16& acc, FP32Vec16& a, FP32Vec16& b) {
|
||||
acc = acc + a * b;
|
||||
}
|
||||
|
||||
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
|
||||
*reinterpret_cast<unsigned short *>(ptr) =
|
||||
template <>
|
||||
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
|
||||
*reinterpret_cast<unsigned short*>(ptr) =
|
||||
_cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
|
||||
}
|
||||
|
||||
inline FP16Vec8::FP16Vec8(const FP32Vec8 &v)
|
||||
inline FP16Vec8::FP16Vec8(const FP32Vec8& v)
|
||||
: reg(_mm256_cvtps_ph(v.reg,
|
||||
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm512_cvtps_ph(v.reg,
|
||||
_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) {}
|
||||
#else
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16 &v)
|
||||
: reg(_mm256_insertf128_si256(_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg), FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
inline FP16Vec16::FP16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm256_insertf128_si256(
|
||||
_mm256_castsi128_si256(FP16Vec8(FP32Vec8(v.reg_low)).reg),
|
||||
FP16Vec8(FP32Vec8(v.reg_low)).reg, 1)) {}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
*reinterpret_cast<__bfloat16*>(ptr) = _mm_cvtness_sbh(v);
|
||||
}
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
|
||||
|
||||
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
|
||||
inline void fma(FP32Vec16& acc, BF16Vec32& a, BF16Vec32& b) {
|
||||
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
|
||||
}
|
||||
#else
|
||||
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
|
||||
reinterpret_cast<c10::BFloat16 *>(&v);
|
||||
template <>
|
||||
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
|
||||
c10::BFloat16 __attribute__((__may_alias__))* v_ptr =
|
||||
reinterpret_cast<c10::BFloat16*>(&v);
|
||||
*ptr = *(v_ptr + 1);
|
||||
}
|
||||
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
#ifdef __AVX512F__
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(_mm256_cvtepi32_epi16(
|
||||
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
|
||||
: reg(_mm512_cvtepi32_epi16(
|
||||
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
|
||||
#else
|
||||
namespace{
|
||||
#else
|
||||
namespace {
|
||||
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
__m256i ai = _mm256_castps_si256(a);
|
||||
ai = _mm256_srli_epi32(ai, 16);
|
||||
@ -612,21 +631,21 @@ __m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
|
||||
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
|
||||
return _mm256_extracti128_si256(ai, 0);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
|
||||
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
|
||||
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
|
||||
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
|
||||
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) {
|
||||
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
|
||||
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
|
||||
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
|
||||
}
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
#endif // __AVX512F__
|
||||
#endif // __AVX512BF16__
|
||||
|
||||
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
inline void prefetch(const void* addr) { _mm_prefetch(addr, _MM_HINT_T1); }
|
||||
|
||||
}; // namespace vec_op
|
||||
}; // namespace vec_op
|
||||
|
||||
#endif
|
||||
|
@ -359,7 +359,7 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& b, // [IC, OC], column-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const c10::optional<torch::Tensor>& bias // [OC]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
|
||||
// Checks for conformality
|
||||
@ -442,8 +442,8 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
|
||||
const torch::Tensor& a_scales, // [1] or [M]
|
||||
const torch::Tensor& b_scales, // [1] or [OC]
|
||||
const torch::Tensor& azp_adj, // [OC]
|
||||
const c10::optional<torch::Tensor>& azp, // [1] or [M]
|
||||
const c10::optional<torch::Tensor>& bias // [OC]
|
||||
const std::optional<torch::Tensor>& azp, // [1] or [M]
|
||||
const std::optional<torch::Tensor>& bias // [OC]
|
||||
) {
|
||||
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
|
||||
// Checks for conformality
|
||||
@ -561,7 +561,7 @@ void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
const torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& azp) {
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
@ -590,7 +590,7 @@ void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
const torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& scale, // [..., 1]
|
||||
c10::optional<torch::Tensor> const& azp) {
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
@ -9,14 +9,14 @@ std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
|
||||
const torch::Tensor& b, const torch::Tensor& a_scales,
|
||||
const torch::Tensor& b_scales,
|
||||
const torch::Tensor& azp_adj,
|
||||
const c10::optional<torch::Tensor>& azp,
|
||||
const c10::optional<torch::Tensor>& bias);
|
||||
const std::optional<torch::Tensor>& azp,
|
||||
const std::optional<torch::Tensor>& bias);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
" Tensor k_scale, Tensor v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
}
|
||||
|
||||
|
@ -1,10 +1,22 @@
|
||||
#include <numa.h>
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
#ifndef VLLM_NUMA_DISABLED
|
||||
#include <numa.h>
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
#endif
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
#ifdef VLLM_NUMA_DISABLED
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
return std::string(
|
||||
"Warning: NUMA is not enabled in this build. `init_cpu_threads_env` has "
|
||||
"no effect to setup thread affinity.");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#ifndef VLLM_NUMA_DISABLED
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
@ -57,7 +69,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
omp_lock_t writelock;
|
||||
omp_init_lock(&writelock);
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||
cpu_set_t mask;
|
||||
CPU_ZERO(&mask);
|
||||
@ -88,3 +100,4 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
#endif
|
310
csrc/cumem_allocator.cpp
Normal file
310
csrc/cumem_allocator.cpp
Normal file
@ -0,0 +1,310 @@
|
||||
// A CUDAPluggableAllocator based on cumem* APIs.
|
||||
// Important: allocation size, CUdeviceptr and CUmemGenericAllocationHandle*
|
||||
// need to be unsigned long long
|
||||
#include <iostream>
|
||||
|
||||
extern "C" {
|
||||
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
#include <Python.h>
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda.h>
|
||||
|
||||
#define CUDA_CHECK(condition) \
|
||||
do { \
|
||||
CUresult error = condition; \
|
||||
if (error != 0) { \
|
||||
char* error_string; \
|
||||
cuGetErrorString(error, (const char**)&error_string); \
|
||||
std::cerr << "CUDA Error: " << error_string << " at " << __FILE__ << ":" \
|
||||
<< __LINE__ << std::endl; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Global references to Python callables
|
||||
// NOTE: this is borrowed reference, so we don't need to DECREF them.
|
||||
// This brings the limitation that the allocator needs to be singleton.
|
||||
static PyObject* g_python_malloc_callback = nullptr;
|
||||
static PyObject* g_python_free_callback = nullptr;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper functions:
|
||||
|
||||
void ensure_context(unsigned long long device) {
|
||||
CUcontext pctx;
|
||||
CUDA_CHECK(cuCtxGetCurrent(&pctx));
|
||||
if (!pctx) {
|
||||
// Ensure device context.
|
||||
CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device));
|
||||
CUDA_CHECK(cuCtxSetCurrent(pctx));
|
||||
}
|
||||
}
|
||||
|
||||
void create_and_map(unsigned long long device, ssize_t size, CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
ensure_context(device);
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Allocate memory using cuMemCreate
|
||||
CUDA_CHECK(cuMemCreate(p_memHandle, size, &prop, 0));
|
||||
CUDA_CHECK(cuMemMap(d_mem, size, 0, *p_memHandle, 0));
|
||||
|
||||
CUmemAccessDesc accessDesc = {};
|
||||
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
accessDesc.location.id = device;
|
||||
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
|
||||
|
||||
CUDA_CHECK(cuMemSetAccess(d_mem, size, &accessDesc, 1));
|
||||
// std::cout << "create_and_map: device=" << device << ", size=" << size << ",
|
||||
// d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
}
|
||||
|
||||
void unmap_and_release(unsigned long long device, ssize_t size,
|
||||
CUdeviceptr d_mem,
|
||||
CUmemGenericAllocationHandle* p_memHandle) {
|
||||
// std::cout << "unmap_and_release: device=" << device << ", size=" << size <<
|
||||
// ", d_mem=" << d_mem << ", p_memHandle=" << p_memHandle << std::endl;
|
||||
ensure_context(device);
|
||||
CUDA_CHECK(cuMemUnmap(d_mem, size));
|
||||
CUDA_CHECK(cuMemRelease(*p_memHandle));
|
||||
}
|
||||
|
||||
PyObject* create_tuple_from_c_integers(unsigned long long a,
|
||||
unsigned long long b,
|
||||
unsigned long long c,
|
||||
unsigned long long d) {
|
||||
// Create a new tuple of size 4
|
||||
PyObject* tuple = PyTuple_New(4);
|
||||
if (!tuple) {
|
||||
return NULL; // Return NULL on failure
|
||||
}
|
||||
|
||||
// Convert integers to Python objects and set them in the tuple
|
||||
PyTuple_SetItem(
|
||||
tuple, 0,
|
||||
PyLong_FromUnsignedLongLong(a)); // Steals reference to the PyLong
|
||||
PyTuple_SetItem(tuple, 1, PyLong_FromUnsignedLongLong(b));
|
||||
PyTuple_SetItem(tuple, 2, PyLong_FromUnsignedLongLong(c));
|
||||
PyTuple_SetItem(tuple, 3, PyLong_FromUnsignedLongLong(d));
|
||||
|
||||
// Note: PyTuple_SetItem "steals" a reference to each object,
|
||||
// so we do not need to Py_DECREF the PyLong objects explicitly.
|
||||
|
||||
return tuple; // Return the created tuple
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Our exported C functions that call Python:
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void* my_malloc(ssize_t size, int device, CUstream stream) {
|
||||
ensure_context(device);
|
||||
|
||||
// first allocation, align the size, and reserve an address, and also allocate
|
||||
// a CUmemGenericAllocationHandle
|
||||
|
||||
// Define memory allocation properties
|
||||
CUmemAllocationProp prop = {};
|
||||
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = device;
|
||||
prop.allocFlags.compressionType = CU_MEM_ALLOCATION_COMP_NONE;
|
||||
|
||||
// Check if the allocation is supported
|
||||
size_t granularity;
|
||||
CUDA_CHECK(cuMemGetAllocationGranularity(&granularity, &prop,
|
||||
CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
|
||||
size_t alignedSize = ((size + granularity - 1) / granularity) * granularity;
|
||||
|
||||
CUdeviceptr d_mem;
|
||||
CUDA_CHECK(cuMemAddressReserve(&d_mem, alignedSize, 0, 0, 0));
|
||||
|
||||
// allocate the CUmemGenericAllocationHandle
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)malloc(
|
||||
sizeof(CUmemGenericAllocationHandle));
|
||||
|
||||
if (!g_python_malloc_callback) {
|
||||
std::cerr << "ERROR: g_python_malloc_callback not set.\n";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* arg_tuple = create_tuple_from_c_integers(
|
||||
(unsigned long long)device, (unsigned long long)alignedSize,
|
||||
(unsigned long long)d_mem, (unsigned long long)p_memHandle);
|
||||
|
||||
// Call g_python_malloc_callback
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_malloc_callback, arg_tuple, NULL);
|
||||
Py_DECREF(arg_tuple);
|
||||
|
||||
if (!py_result) {
|
||||
PyErr_Print();
|
||||
PyGILState_Release(gstate);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// do the final mapping
|
||||
create_and_map(device, alignedSize, d_mem, p_memHandle);
|
||||
|
||||
return (void*)d_mem;
|
||||
}
|
||||
|
||||
// use CUstream instead of cudaStream_t, to avoid including cuda_runtime_api.h
|
||||
void my_free(void* ptr, ssize_t size, int device, CUstream stream) {
|
||||
// get memory handle from the pointer
|
||||
if (!g_python_free_callback) {
|
||||
std::cerr << "ERROR: g_python_free_callback not set.\n";
|
||||
return;
|
||||
}
|
||||
|
||||
// Acquire GIL (not in stable ABI officially, but often works)
|
||||
PyGILState_STATE gstate = PyGILState_Ensure();
|
||||
|
||||
PyObject* py_ptr =
|
||||
PyLong_FromUnsignedLongLong(reinterpret_cast<unsigned long long>(ptr));
|
||||
|
||||
PyObject* py_result =
|
||||
PyObject_CallFunctionObjArgs(g_python_free_callback, py_ptr, NULL);
|
||||
|
||||
if (!py_result || !PyTuple_Check(py_result) || PyTuple_Size(py_result) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(py_result, "KKKK", &recv_device, &recv_size,
|
||||
&recv_d_mem, &recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return;
|
||||
}
|
||||
|
||||
PyGILState_Release(gstate);
|
||||
|
||||
// recv_size == size
|
||||
// recv_device == device
|
||||
|
||||
// Free memory
|
||||
|
||||
CUdeviceptr d_mem = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
unmap_and_release(device, size, d_mem, p_memHandle);
|
||||
|
||||
// free address and the handle
|
||||
CUDA_CHECK(cuMemAddressFree(d_mem, size));
|
||||
free(p_memHandle);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Python extension boilerplate:
|
||||
|
||||
// Python-exposed function: init_module(python_malloc, python_free)
|
||||
static PyObject* py_init_module(PyObject* self, PyObject* args) {
|
||||
PyObject* malloc_callback = nullptr;
|
||||
PyObject* free_callback = nullptr;
|
||||
|
||||
if (!PyArg_ParseTuple(args, "OO", &malloc_callback, &free_callback)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!PyCallable_Check(malloc_callback) || !PyCallable_Check(free_callback)) {
|
||||
PyErr_SetString(PyExc_TypeError, "Both arguments must be callables");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Save the Python callables
|
||||
// This module does not handle GC of these objects, so they must be kept alive
|
||||
// outside of this module.
|
||||
g_python_malloc_callback = malloc_callback;
|
||||
g_python_free_callback = free_callback;
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_unmap_and_release(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
unmap_and_release(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* python_create_and_map(PyObject* self, PyObject* args) {
|
||||
if (!args || !PyTuple_Check(args) || PyTuple_Size(args) != 4) {
|
||||
PyErr_SetString(PyExc_TypeError, "Expected a tuple of size 4");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
unsigned long long recv_device, recv_size;
|
||||
unsigned long long recv_d_mem, recv_p_memHandle;
|
||||
// Unpack the tuple into four C integers
|
||||
if (!PyArg_ParseTuple(args, "KKKK", &recv_device, &recv_size, &recv_d_mem,
|
||||
&recv_p_memHandle)) {
|
||||
// PyArg_ParseTuple sets an error if it fails
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CUdeviceptr d_mem_ptr = (CUdeviceptr)recv_d_mem;
|
||||
CUmemGenericAllocationHandle* p_memHandle =
|
||||
(CUmemGenericAllocationHandle*)recv_p_memHandle;
|
||||
|
||||
create_and_map(recv_device, recv_size, d_mem_ptr, p_memHandle);
|
||||
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyMethodDef module_methods[] = {
|
||||
{"init_module", (PyCFunction)py_init_module, METH_VARARGS,
|
||||
"Initialize module with python_malloc and python_free callables."},
|
||||
{"python_create_and_map", (PyCFunction)python_create_and_map, METH_VARARGS,
|
||||
"Create and map memory on the device."},
|
||||
{"python_unmap_and_release", (PyCFunction)python_unmap_and_release,
|
||||
METH_VARARGS, "Unmap and release memory on the device."},
|
||||
{NULL, NULL, 0, NULL} // sentinel
|
||||
};
|
||||
|
||||
static struct PyModuleDef cumem_allocator_module = {
|
||||
PyModuleDef_HEAD_INIT, "cumem_allocator",
|
||||
"cumem-based allocator for CUDAPluggableAllocator", -1, module_methods};
|
||||
|
||||
PyMODINIT_FUNC PyInit_cumem_allocator(void) {
|
||||
// Initialize the module
|
||||
PyObject* module = PyModule_Create(&cumem_allocator_module);
|
||||
if (!module) {
|
||||
return NULL;
|
||||
}
|
||||
return module;
|
||||
}
|
||||
} // extern "C"
|
@ -38,9 +38,13 @@ struct Signal {
|
||||
alignas(128) FlagType peer_counter[2][kMaxBlocks][8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
|
||||
struct __align__(16) RankData {
|
||||
const void* __restrict__ ptrs[8];
|
||||
};
|
||||
|
||||
struct __align__(16) RankSignals { Signal* signals[8]; };
|
||||
struct __align__(16) RankSignals {
|
||||
Signal* signals[8];
|
||||
};
|
||||
|
||||
// like std::array, but aligned
|
||||
template <typename T, int sz>
|
||||
|
@ -27,8 +27,7 @@
|
||||
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
|
||||
int max_shared_mem_per_block_opt_in = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin,
|
||||
device);
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
return max_shared_mem_per_block_opt_in;
|
||||
}
|
||||
|
||||
|
@ -68,7 +68,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@ -301,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
@ -67,7 +67,7 @@ struct ScaledEpilogueBase {
|
||||
// This overload handles the case where there might not be a tensor, in which
|
||||
// case a nullptr is passed and a constant (0) is used.
|
||||
template <typename Descriptor, typename T>
|
||||
static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
|
||||
static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
|
||||
using Arguments = typename Descriptor::Arguments;
|
||||
auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
|
||||
static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
|
||||
@ -223,7 +223,7 @@ struct ScaledEpilogueBiasAzp
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
@ -299,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
torch::Tensor const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
|
||||
auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
|
||||
auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
|
||||
|
@ -97,7 +97,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
|
||||
|
||||
template <typename Stride>
|
||||
static inline auto maybe_make_cute_layout(
|
||||
c10::optional<torch::Tensor> const& tensor,
|
||||
std::optional<torch::Tensor> const& tensor,
|
||||
std::string_view name = "tensor") {
|
||||
using Layout = decltype(make_cute_layout<Stride>(*tensor));
|
||||
|
||||
|
@ -14,9 +14,9 @@ class VLLMDataType(enum.Enum):
|
||||
|
||||
|
||||
class MixedInputKernelScheduleType(enum.Enum):
|
||||
TmaWarpSpecializedMixedInput = enum_auto()
|
||||
TmaWarpSpecializedPingpongMixedInput = enum_auto()
|
||||
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
|
||||
TmaWarpSpecialized = enum_auto()
|
||||
TmaWarpSpecializedPingpong = enum_auto()
|
||||
TmaWarpSpecializedCooperative = enum_auto()
|
||||
|
||||
|
||||
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
|
||||
@ -68,11 +68,11 @@ VLLMKernelScheduleTag: Dict[Union[
|
||||
MixedInputKernelScheduleType, KernelScheduleType], str] = {
|
||||
**KernelScheduleTag, # type: ignore
|
||||
**{
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecialized:
|
||||
"cutlass::gemm::KernelTmaWarpSpecialized",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedPingpong",
|
||||
MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
|
||||
"cutlass::gemm::KernelTmaWarpSpecializedCooperative",
|
||||
}
|
||||
}
|
||||
|
@ -53,12 +53,12 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
@ -93,11 +93,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const c10::optional<at::Tensor> &conv_states,
|
||||
const c10::optional<at::Tensor> &query_start_loc,
|
||||
const c10::optional<at::Tensor> &cache_indices,
|
||||
const c10::optional<at::Tensor> &has_initial_state,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
@ -194,10 +194,10 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const c10::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor> &cache_seqlens_,
|
||||
const c10::optional<at::Tensor> &conv_state_indices_,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
|
@ -402,14 +402,14 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
const torch::Tensor out,
|
||||
const torch::Tensor z,
|
||||
const torch::Tensor out_z,
|
||||
const c10::optional<at::Tensor>& D,
|
||||
const c10::optional<at::Tensor>& delta_bias,
|
||||
const std::optional<at::Tensor>& D,
|
||||
const std::optional<at::Tensor>& delta_bias,
|
||||
const torch::Tensor ssm_states,
|
||||
bool has_z,
|
||||
bool delta_softplus,
|
||||
const c10::optional<at::Tensor>& query_start_loc,
|
||||
const c10::optional<at::Tensor>& cache_indices,
|
||||
const c10::optional<at::Tensor>& has_initial_state,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool varlen,
|
||||
int64_t pad_slot_id) {
|
||||
|
||||
@ -504,13 +504,13 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms,
|
||||
|
||||
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
|
||||
const c10::optional<torch::Tensor> &D_,
|
||||
const c10::optional<torch::Tensor> &z_,
|
||||
const c10::optional<torch::Tensor> &delta_bias_,
|
||||
const std::optional<torch::Tensor> &D_,
|
||||
const std::optional<torch::Tensor> &z_,
|
||||
const std::optional<torch::Tensor> &delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor> &query_start_loc,
|
||||
const c10::optional<torch::Tensor> &cache_indices,
|
||||
const c10::optional<torch::Tensor> &has_initial_state,
|
||||
const std::optional<torch::Tensor> &query_start_loc,
|
||||
const std::optional<torch::Tensor> &cache_indices,
|
||||
const std::optional<torch::Tensor> &has_initial_state,
|
||||
const torch::Tensor &ssm_states,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
|
@ -138,8 +138,8 @@ __device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
@ -182,8 +182,8 @@ __device__ inline FragB dequant<vllm::kU4.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
|
@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename token_cnts_t>
|
||||
__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
int32_t* sorted_token_ids,
|
||||
int32_t* expert_ids,
|
||||
@ -32,12 +32,10 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
const size_t start_idx = threadIdx.x * tokens_per_thread;
|
||||
|
||||
extern __shared__ int32_t shared_mem[];
|
||||
|
||||
int32_t* tokens_cnts =
|
||||
shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
|
||||
int32_t* cumsum =
|
||||
shared_mem +
|
||||
(blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
|
||||
int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
|
||||
token_cnts_t* tokens_cnts =
|
||||
(token_cnts_t*)(shared_mem + num_experts +
|
||||
1); // 2d tensor with shape (blockDim.x + 1, num_experts)
|
||||
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
|
||||
@ -74,7 +72,7 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
|
||||
block_size) *
|
||||
block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -224,26 +222,46 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
torch::Tensor num_tokens_post_pad) {
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// If we have very large number of experts, we can no longer use shared
|
||||
// memory.
|
||||
// TODO(simon): the right solution should be calculating the exact right
|
||||
// amount of shared memory and use that. The num_experts >= 256 is just a
|
||||
// temporary solution to unblock Deepseek V3.
|
||||
if (num_experts >= 256) {
|
||||
int device_max_shared_mem;
|
||||
auto dev = topk_ids.get_device();
|
||||
cudaDeviceGetAttribute(&device_max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem_i32 =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
|
||||
const int32_t shared_mem_i16 =
|
||||
((num_thread + 1) * num_experts) * sizeof(uint16_t) +
|
||||
(num_experts + 1) * sizeof(int32_t);
|
||||
|
||||
bool use_global_memory = false;
|
||||
bool use_i16 = false; // Use uint16_t for shared memory token counts
|
||||
if (shared_mem_i32 < device_max_shared_mem) {
|
||||
// Do nothing in this case. We're all set to use int32_t token counts
|
||||
} else if (shared_mem_i16 < device_max_shared_mem &&
|
||||
topk_ids.numel() <= 65535) {
|
||||
// when nelements of topk_ids is smaller than 65535 (max value of uint16),
|
||||
// element value of token_cnts would also smaller than 65535,
|
||||
// so we can use uint16 as dtype of token_cnts
|
||||
use_i16 = true;
|
||||
} else {
|
||||
use_global_memory = true;
|
||||
}
|
||||
|
||||
if (use_global_memory) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
|
||||
const int32_t mem_tokens_cnts =
|
||||
((num_experts + 1) * num_experts) * sizeof(int32_t);
|
||||
const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
|
||||
// allocate global memory
|
||||
int32_t* tokens_cnts;
|
||||
int32_t* cumsum;
|
||||
cudaMalloc(&tokens_cnts, mem_tokens_cnts);
|
||||
cudaMalloc(&cumsum, mem_cumsum);
|
||||
auto options_int = torch::TensorOptions()
|
||||
.dtype(torch::kInt)
|
||||
.device(topk_ids.device());
|
||||
torch::Tensor token_cnts_buffer =
|
||||
torch::empty({(num_experts + 1) * num_experts}, options_int);
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
|
||||
@ -252,25 +270,32 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel(), tokens_cnts, cumsum);
|
||||
cudaFree(tokens_cnts);
|
||||
cudaFree(cumsum);
|
||||
topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
|
||||
cumsum_buffer.data_ptr<int32_t>());
|
||||
});
|
||||
} else if (use_i16) {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// set dynamic shared mem
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem_i16));
|
||||
kernel<<<1, num_thread, shared_mem_i16, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
});
|
||||
} else {
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
|
||||
// tensors
|
||||
const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
|
||||
const int32_t shared_mem =
|
||||
((num_thread + 1) * num_experts + (num_experts + 1)) *
|
||||
sizeof(int32_t);
|
||||
|
||||
// set dynamic shared mem
|
||||
auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
auto kernel =
|
||||
vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
|
||||
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
|
||||
(void*)kernel, shared_mem));
|
||||
kernel<<<1, num_thread, shared_mem, stream>>>(
|
||||
(void*)kernel, shared_mem_i32));
|
||||
kernel<<<1, num_thread, shared_mem_i32, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
|
58
csrc/ops.h
58
csrc/ops.h
@ -33,9 +33,10 @@ void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
@ -44,9 +45,10 @@ void paged_attention_v2(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
@ -86,6 +88,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
|
||||
|
||||
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
|
||||
@ -153,15 +157,15 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
|
||||
|
||||
@ -169,7 +173,7 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& e,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
|
||||
torch::Tensor& e, torch::Tensor const& a);
|
||||
@ -177,11 +181,11 @@ bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale,
|
||||
c10::optional<torch::Tensor> const& azp);
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scales,
|
||||
c10::optional<torch::Tensor> const& azp);
|
||||
std::optional<torch::Tensor> const& azp);
|
||||
|
||||
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
torch::Tensor b_gptq_qzeros,
|
||||
@ -198,34 +202,34 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& scale_ub);
|
||||
std::optional<torch::Tensor> const& scale_ub);
|
||||
|
||||
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const torch::Tensor& A, const torch::Tensor& B,
|
||||
const torch::Tensor& C,
|
||||
const c10::optional<torch::Tensor>& D_,
|
||||
const c10::optional<torch::Tensor>& z_,
|
||||
const c10::optional<torch::Tensor>& delta_bias_,
|
||||
const std::optional<torch::Tensor>& D_,
|
||||
const std::optional<torch::Tensor>& z_,
|
||||
const std::optional<torch::Tensor>& delta_bias_,
|
||||
bool delta_softplus,
|
||||
const c10::optional<torch::Tensor>& query_start_loc,
|
||||
const c10::optional<torch::Tensor>& cache_indices,
|
||||
const c10::optional<torch::Tensor>& has_initial_state,
|
||||
const std::optional<torch::Tensor>& query_start_loc,
|
||||
const std::optional<torch::Tensor>& cache_indices,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const c10::optional<at::Tensor>& cache_seqlens_,
|
||||
const c10::optional<at::Tensor>& conv_state_indices_,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias_,
|
||||
const c10::optional<at::Tensor>& conv_states,
|
||||
const c10::optional<at::Tensor>& query_start_loc,
|
||||
const c10::optional<at::Tensor>& cache_indices,
|
||||
const c10::optional<at::Tensor>& has_initial_state,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
|
||||
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr, int64_t const block_tables_stride,
|
||||
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
|
||||
int const n_pad = num_seqs - num_queries;
|
||||
if (n_pad && blockIdx.x == 0) {
|
||||
// Handle cuda graph padding
|
||||
int const offset = num_queries;
|
||||
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
|
||||
input_tokens_ptr[offset + i] = 0;
|
||||
input_positions_ptr[offset + i] = 0;
|
||||
slot_mapping_ptr[offset + i] = -1;
|
||||
}
|
||||
}
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x < num_query_blocks) {
|
||||
|
@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor const& scale,
|
||||
c10::optional<torch::Tensor> const& azp) {
|
||||
std::optional<torch::Tensor> const& azp) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scale.numel() == 1);
|
||||
@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
void dynamic_scaled_int8_quant(
|
||||
torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor const& input, // [..., hidden_size]
|
||||
torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
|
||||
torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(scales.is_contiguous());
|
||||
|
@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
|
@ -51,7 +51,7 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
if (bias) {
|
||||
@ -70,8 +70,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
|
@ -9,26 +9,26 @@ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
|
||||
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -36,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
@ -61,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
@ -84,7 +84,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
@ -148,8 +148,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& azp_adj,
|
||||
c10::optional<torch::Tensor> const& azp,
|
||||
c10::optional<torch::Tensor> const& bias) {
|
||||
std::optional<torch::Tensor> const& azp,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
// Checks for conformality
|
||||
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
|
||||
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
|
||||
|
@ -173,8 +173,8 @@ dequant<half, vllm::kU4B8.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
@ -197,9 +197,9 @@ dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
@ -221,8 +221,8 @@ dequant<half, vllm::kU4.id()>(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
@ -244,9 +244,9 @@ dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
@ -834,6 +834,7 @@ __global__ void Marlin(
|
||||
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
int4* sh_red = sh_s + (stages * s_sh_stage);
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
FragA frag_a[2][thread_m_blocks];
|
||||
@ -932,11 +933,11 @@ __global__ void Marlin(
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
}
|
||||
// Only fetch scales if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (s_sh_wr_pred) {
|
||||
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
|
||||
}
|
||||
if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
|
||||
s_gl_rd += s_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
@ -1038,9 +1039,7 @@ __global__ void Marlin(
|
||||
// No act-order case
|
||||
if constexpr (group_blocks != -1) {
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_s_stage =
|
||||
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
|
||||
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
@ -1339,15 +1338,15 @@ __global__ void Marlin(
|
||||
int red_sh_wr =
|
||||
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
|
||||
if (i < red_off) {
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
|
||||
float* c_rd = reinterpret_cast<float*>(
|
||||
&sh_red[red_sh_delta * j + red_sh_rd]);
|
||||
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < 4; k++)
|
||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
|
||||
c_rd[k] + c_wr[k];
|
||||
}
|
||||
sh[red_sh_wr] =
|
||||
sh_red[red_sh_wr] =
|
||||
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
|
||||
}
|
||||
}
|
||||
@ -1357,7 +1356,7 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4 * 2; i++) {
|
||||
float* c_rd =
|
||||
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
|
||||
reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
|
||||
@ -1397,7 +1396,7 @@ __global__ void Marlin(
|
||||
#pragma unroll
|
||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||
cp_async4_pred(
|
||||
&sh[c_sh_wr + c_sh_wr_delta * i],
|
||||
&sh_red[c_sh_wr + c_sh_wr_delta * i],
|
||||
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
|
||||
c_gl_wr_delta_i * (i % 2)],
|
||||
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
|
||||
@ -1410,7 +1409,7 @@ __global__ void Marlin(
|
||||
for (int i = 0; i < thread_m_blocks * 4; i++) {
|
||||
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
|
||||
if (!first) {
|
||||
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
|
||||
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2 * 4; j++) {
|
||||
reinterpret_cast<float*>(
|
||||
@ -1461,10 +1460,10 @@ __global__ void Marlin(
|
||||
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
sh[threadIdx.x] =
|
||||
sh_red[threadIdx.x] =
|
||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
||||
|
||||
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
|
||||
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
|
||||
#pragma unroll
|
||||
for (int f = 0; f < 4; f++) {
|
||||
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
||||
@ -1515,7 +1514,7 @@ __global__ void Marlin(
|
||||
res = __hmul2(res, s[0]);
|
||||
}
|
||||
|
||||
((scalar_t2*)sh)[idx] = res;
|
||||
((scalar_t2*)sh_red)[idx] = res;
|
||||
};
|
||||
|
||||
if (threadIdx.x / 32 < thread_n_blocks / 4) {
|
||||
@ -1543,7 +1542,7 @@ __global__ void Marlin(
|
||||
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
|
||||
i++) {
|
||||
if (c_gl_wr < c_gl_wr_end) {
|
||||
C[c_gl_wr] = sh[c_sh_rd];
|
||||
C[c_gl_wr] = sh_red[c_sh_rd];
|
||||
c_gl_wr += c_gl_wr_delta;
|
||||
c_sh_rd += c_sh_rd_delta;
|
||||
}
|
||||
@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
|
||||
|
||||
float pipe_size = (a_size + b_size) * pipe_stages;
|
||||
|
||||
float reduce_size = max(th_config.num_threads * 32 * 4,
|
||||
(tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
|
||||
|
||||
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
|
||||
|
||||
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
|
||||
return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
|
||||
}
|
||||
|
||||
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
||||
|
@ -63,7 +63,7 @@ torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) {
|
||||
|
||||
|
||||
static inline std::optional<at::ScalarType> maybe_scalartype(
|
||||
c10::optional<at::Tensor> const& t) {
|
||||
std::optional<at::Tensor> const& t) {
|
||||
if (!t) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
@ -189,7 +189,7 @@ using Kernel_{{type_sig}} = MacheteKernelTemplate<
|
||||
{{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT
|
||||
{{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT
|
||||
{{DataTypeTag[t.a_token_scale]}}, // TokenScaleT
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative,
|
||||
Sch>;
|
||||
|
||||
{% for sch in schs %}
|
||||
@ -223,7 +223,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
{{DataTypeTag[t.convert]}}, // ElementConvert
|
||||
{{DataTypeTag[t.accumulator]}}, // Accumulator
|
||||
cutlass::layout::ColumnMajor,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative>
|
||||
>(args.B);
|
||||
}
|
||||
{%- endfor %}
|
||||
@ -239,7 +239,7 @@ torch::Tensor prepack_B_dispatch(PrepackBArgs args) {
|
||||
}; // namespace machete
|
||||
"""
|
||||
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
|
||||
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperative
|
||||
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
|
||||
|
||||
|
||||
@ -300,7 +300,7 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
# mostly unique shorter sch_sig
|
||||
def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
|
||||
kernel_terse_names_replace = {
|
||||
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
|
||||
"KernelTmaWarpSpecializedCooperative": "TmaMI_",
|
||||
"TmaWarpSpecializedCooperative_": "TmaCoop_",
|
||||
"StreamKScheduler": "streamK",
|
||||
}
|
||||
|
@ -18,16 +18,14 @@ struct VLLMCollectiveBuilder<
|
||||
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
|
||||
KernelScheduleType,
|
||||
cute::enable_if_t<(
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedMixedInput> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedPingpongMixedInput> ||
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
|
||||
KernelTmaWarpSpecializedCooperative>)>> {
|
||||
using CollectiveOp = machete::MacheteCollectiveMma<
|
||||
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
|
||||
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
|
||||
StageCountType, KernelScheduleType>;
|
||||
};
|
||||
|
||||
}; // namespace cutlass::gemm::collective
|
||||
}; // namespace cutlass::gemm::collective
|
||||
|
@ -66,13 +66,11 @@ struct MacheteCollectiveMma {
|
||||
using Schedule = KernelScheduleType;
|
||||
static_assert(
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
|
||||
cute::is_same_v<Schedule,
|
||||
KernelTmaWarpSpecializedPingpongMixedInput> ||
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
|
||||
cute::is_same_v<Schedule,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
|
||||
"KernelSchedule must be one of the warp specialized policies");
|
||||
|
||||
public:
|
||||
@ -113,8 +111,7 @@ struct MacheteCollectiveMma {
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelScheduleType,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
|
@ -183,11 +183,11 @@ struct MacheteKernelTemplate {
|
||||
torch::Tensor const& A, // MxK matrix
|
||||
torch::Tensor const& B, // KxN prepacked matrix
|
||||
torch::Tensor& D, // MxN matrix
|
||||
c10::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
c10::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
c10::optional<int64_t> maybe_group_size,
|
||||
c10::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
|
||||
c10::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
|
||||
std::optional<torch::Tensor> const& maybe_g_scales, // scale_KxN matrix
|
||||
std::optional<torch::Tensor> const& maybe_g_zeros, // scale_KxN matrix
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<torch::Tensor> const& maybe_ch_scales, // len N vector
|
||||
std::optional<torch::Tensor> const& maybe_tok_scales) // len M vector
|
||||
{
|
||||
static_assert(!with_group_zeropoints || with_group_scales);
|
||||
|
||||
|
@ -13,23 +13,23 @@ struct MMArgs {
|
||||
torch::Tensor const& A;
|
||||
torch::Tensor const& B;
|
||||
vllm::ScalarType const& b_type;
|
||||
c10::optional<at::ScalarType> const& maybe_out_type;
|
||||
c10::optional<torch::Tensor> const& maybe_group_scales;
|
||||
c10::optional<torch::Tensor> const& maybe_group_zeros;
|
||||
c10::optional<int64_t> maybe_group_size;
|
||||
c10::optional<torch::Tensor> const& maybe_channel_scales;
|
||||
c10::optional<torch::Tensor> const& maybe_token_scales;
|
||||
c10::optional<std::string> maybe_schedule;
|
||||
std::optional<at::ScalarType> const& maybe_out_type;
|
||||
std::optional<torch::Tensor> const& maybe_group_scales;
|
||||
std::optional<torch::Tensor> const& maybe_group_zeros;
|
||||
std::optional<int64_t> maybe_group_size;
|
||||
std::optional<torch::Tensor> const& maybe_channel_scales;
|
||||
std::optional<torch::Tensor> const& maybe_token_scales;
|
||||
std::optional<std::string> maybe_schedule;
|
||||
};
|
||||
|
||||
struct SupportedSchedulesArgs {
|
||||
at::ScalarType a_type;
|
||||
vllm::ScalarType b_type;
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_group_zeros_type;
|
||||
c10::optional<at::ScalarType> maybe_channel_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_token_scales_type;
|
||||
c10::optional<at::ScalarType> maybe_out_type;
|
||||
std::optional<at::ScalarType> maybe_group_scales_type;
|
||||
std::optional<at::ScalarType> maybe_group_zeros_type;
|
||||
std::optional<at::ScalarType> maybe_channel_scales_type;
|
||||
std::optional<at::ScalarType> maybe_token_scales_type;
|
||||
std::optional<at::ScalarType> maybe_out_type;
|
||||
};
|
||||
|
||||
torch::Tensor mm_dispatch(MMArgs args);
|
||||
|
@ -10,7 +10,7 @@ struct PrepackBArgs {
|
||||
torch::Tensor const& B;
|
||||
at::ScalarType a_type;
|
||||
vllm::ScalarType b_type;
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type;
|
||||
std::optional<at::ScalarType> maybe_group_scales_type;
|
||||
};
|
||||
|
||||
template <typename PrepackedLayoutB>
|
||||
|
@ -98,8 +98,7 @@ struct PrepackedLayoutBTemplate {
|
||||
// For coop schedules we have two warp groups cooperatively issuing wgmma
|
||||
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
|
||||
using AtomLayoutMNK = cute::conditional_t<
|
||||
cute::is_same_v<KernelSchedule,
|
||||
KernelTmaWarpSpecializedCooperativeMixedInput>,
|
||||
cute::is_same_v<KernelSchedule, KernelTmaWarpSpecializedCooperative>,
|
||||
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
|
||||
|
||||
using TiledMma = decltype(cute::make_tiled_mma(
|
||||
@ -247,4 +246,4 @@ struct PrepackedLayoutBTemplate {
|
||||
}
|
||||
};
|
||||
|
||||
}; // namespace machete
|
||||
}; // namespace machete
|
||||
|
@ -10,11 +10,11 @@ using namespace vllm;
|
||||
|
||||
std::vector<std::string> supported_schedules(
|
||||
at::ScalarType a_type, int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> maybe_group_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_group_zeros_type,
|
||||
c10::optional<at::ScalarType> maybe_channel_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_token_scales_type,
|
||||
c10::optional<at::ScalarType> maybe_out_type) {
|
||||
std::optional<at::ScalarType> maybe_group_scales_type,
|
||||
std::optional<at::ScalarType> maybe_group_zeros_type,
|
||||
std::optional<at::ScalarType> maybe_channel_scales_type,
|
||||
std::optional<at::ScalarType> maybe_token_scales_type,
|
||||
std::optional<at::ScalarType> maybe_out_type) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return supported_schedules_dispatch({
|
||||
.a_type = a_type,
|
||||
@ -29,13 +29,13 @@ std::vector<std::string> supported_schedules(
|
||||
|
||||
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> const& maybe_out_type,
|
||||
c10::optional<torch::Tensor> const& maybe_group_scales,
|
||||
c10::optional<torch::Tensor> const& maybe_group_zeros,
|
||||
c10::optional<int64_t> maybe_group_size,
|
||||
c10::optional<torch::Tensor> const& maybe_channel_scales,
|
||||
c10::optional<torch::Tensor> const& maybe_token_scales,
|
||||
c10::optional<std::string> maybe_schedule) {
|
||||
std::optional<at::ScalarType> const& maybe_out_type,
|
||||
std::optional<torch::Tensor> const& maybe_group_scales,
|
||||
std::optional<torch::Tensor> const& maybe_group_zeros,
|
||||
std::optional<int64_t> maybe_group_size,
|
||||
std::optional<torch::Tensor> const& maybe_channel_scales,
|
||||
std::optional<torch::Tensor> const& maybe_token_scales,
|
||||
std::optional<std::string> maybe_schedule) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return mm_dispatch({.A = A,
|
||||
.B = B,
|
||||
@ -51,7 +51,7 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
|
||||
|
||||
torch::Tensor prepack_B(
|
||||
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
|
||||
c10::optional<at::ScalarType> const& maybe_group_scales_type) {
|
||||
std::optional<at::ScalarType> const& maybe_group_scales_type) {
|
||||
ScalarType const b_type = ScalarType::from_id(b_type_id);
|
||||
return prepack_B_dispatch(
|
||||
{.B = B,
|
||||
|
@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) {
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
const int SUB = 0x64086408;
|
||||
|
@ -141,8 +141,8 @@ __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) {
|
||||
static constexpr uint32_t HI = 0x00f000f0;
|
||||
static constexpr uint32_t EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
uint32_t t0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
|
||||
uint32_t t1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
|
||||
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
|
||||
// directly into `SUB` and `ADD`.
|
||||
static constexpr uint32_t SUB = 0x64086408;
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user