Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: 400 commits, v0.11.1rc0 ... skip-lmfe-
Author | SHA1 | Date | |
---|---|---|---|
37d0a00b16 | |||
e94cfd51da | |||
7c12763b24 | |||
3b780a4bbb | |||
30f78af147 | |||
19a9b169bf | |||
96ad65b7fe | |||
8d2b8c0ff2 | |||
b2155ed317 | |||
910abdbd08 | |||
cddce79fda | |||
e519281920 | |||
7b03584de8 | |||
ae9d0e7da5 | |||
0e67102d93 | |||
f4ba2061cf | |||
1e6848a65d | |||
67661375fa | |||
213b64452a | |||
784c231151 | |||
606b00e80f | |||
720d3cd0f0 | |||
ab196edefb | |||
3ee202ea1e | |||
ad430a67ca | |||
6f0f570c43 | |||
b545a0b207 | |||
29255cfc3b | |||
da4455609d | |||
aafb99a4d4 | |||
757fa4a4da | |||
c6187f55f7 | |||
8983e0216f | |||
1ee35382cb | |||
6e783bc54b | |||
c9d33c60dc | |||
2e54db4d2b | |||
44f633dba1 | |||
a462331e36 | |||
4069db3f2e | |||
0d37450eb7 | |||
47e66c24e2 | |||
3b736e1c38 | |||
2c1c7dfb35 | |||
e246ad6f0c | |||
5728da11ea | |||
92be3f3517 | |||
d1ddf340c8 | |||
ec10fd0abc | |||
0426e3c5e1 | |||
4bdf7ac593 | |||
dc7976dd9f | |||
e4791438ed | |||
e6e898f95d | |||
ddcbc2f334 | |||
a83ff278d6 | |||
cf4cd6c24f | |||
b960441812 | |||
1317028aa8 | |||
5e49c3e777 | |||
0d7c3cb51d | |||
1b2c440cd6 | |||
0f29dca988 | |||
d24cf322e1 | |||
d17f0fbf30 | |||
43ab8cfaa5 | |||
de253d63b7 | |||
8bd696fa53 | |||
bb6d8c21f9 | |||
ebf6ef1a9b | |||
0c52d6ef81 | |||
467a4f98f1 | |||
e614ab7806 | |||
2a03f93de9 | |||
da364615fc | |||
f08919b7d1 | |||
93f2c0aa08 | |||
4ebc9108a7 | |||
e1ba235668 | |||
b82f4307c9 | |||
76879cc160 | |||
b25d7b5657 | |||
e09d1753ec | |||
4ba8875749 | |||
6273fe8d3d | |||
9fb3ae4e6f | |||
76afe4edf8 | |||
c1b06fc182 | |||
241b4cfe66 | |||
9fc983c707 | |||
2f99f2f506 | |||
338b1bf04f | |||
e39dc46f8f | |||
10c75b5439 | |||
f9582fd8f4 | |||
f377333bd7 | |||
f8607863d8 | |||
335b28f7d1 | |||
5e65d6b2ad | |||
0d4f48fa10 | |||
127c8b782a | |||
cd9890544b | |||
067da2d1df | |||
046118b938 | |||
b32260ab85 | |||
f80e7866c0 | |||
31a4b3e6c4 | |||
caf8b1c084 | |||
1b86bd8e18 | |||
59012df99b | |||
3d1f67616d | |||
6ebaf43ee4 | |||
0c824fc46f | |||
eb577e4655 | |||
8f36850f73 | |||
29fd2662ba | |||
30a3e5af69 | |||
a38c1bfe09 | |||
320feae6f5 | |||
1e4ecca1d0 | |||
c0a7b89d8e | |||
6f59beaf0b | |||
41f1cf38f2 | |||
08d26a1b7e | |||
63773a6200 | |||
883b42896a | |||
e1098ced95 | |||
d100d78eb3 | |||
7e4cd070b0 | |||
46b0779996 | |||
de342585ff | |||
185d8ed44f | |||
d9836d4517 | |||
5f7e8a916a | |||
4dbdf4a294 | |||
c6873c4e6d | |||
2111b4643c | |||
c50901f3b9 | |||
8229280a9c | |||
f77df94647 | |||
f231e5bc21 | |||
2161efe978 | |||
f23b4c04fd | |||
93540958b8 | |||
44b9af5bb2 | |||
7cd95dc8a3 | |||
c02058c222 | |||
b2ea5ba677 | |||
824a3f403f | |||
05f6846ede | |||
20db99cc69 | |||
6431be808f | |||
4727a8afa7 | |||
b8f603cebe | |||
fc679696f8 | |||
ab5e7d93f4 | |||
0340f45553 | |||
19a00eb210 | |||
391612e78b | |||
77c95f72f7 | |||
59f30d0448 | |||
43c146ca42 | |||
7c2ec0fe87 | |||
039b6bade3 | |||
6c04638214 | |||
91ac7f764d | |||
4be7d7c1c9 | |||
59b477645c | |||
778f554157 | |||
d3c84297c3 | |||
f509a20846 | |||
60bc25e74c | |||
b893d661b1 | |||
6b6e98775f | |||
9c3c21c519 | |||
512b8affa4 | |||
1c0c68202c | |||
5f317530ec | |||
557b2e961d | |||
4e256cadc2 | |||
d6953beb91 | |||
17edd8a807 | |||
3303cfb4ac | |||
b7e8e4e6be | |||
432e1cbc23 | |||
201c971e96 | |||
e0986ea07b | |||
a964e5e6c3 | |||
78c1d5bfd2 | |||
59a85c366e | |||
119f00630b | |||
a42d2df75f | |||
5c057e068f | |||
ed3aeb25a4 | |||
86ee949128 | |||
4570535ec4 | |||
2a6dc67eb5 | |||
f05fea1f5e | |||
d0df145c2a | |||
1838cd4860 | |||
7d6b03381e | |||
7c2e91c4e0 | |||
736fbf4c89 | |||
44ea85137a | |||
d3d649efec | |||
ea507c3a93 | |||
9705fba7b7 | |||
2f7dbc9b42 | |||
ea25a76c05 | |||
67bc0c003e | |||
5a05f26603 | |||
7ef40bb983 | |||
767cbb011d | |||
7cfa4b24bf | |||
b71fcd4905 | |||
75003f34e8 | |||
78b8015a4d | |||
831b124151 | |||
c1ffcb55da | |||
0879736aab | |||
a26917332f | |||
cd9e5b8340 | |||
300a59c4c3 | |||
d76541a6c5 | |||
dd96465fd7 | |||
4f8f47e87e | |||
d78fda7cda | |||
73a99cc2a5 | |||
adae0c1f43 | |||
cbf9221992 | |||
5f42fc53b6 | |||
8ee846c27c | |||
812b7f54a8 | |||
5f2cacdb1e | |||
aa5053e3fe | |||
79aa244678 | |||
2ed3f20dba | |||
48f309029a | |||
0e93ac0b3a | |||
5446ad1d24 | |||
f9a8084e48 | |||
3e70e3d4d5 | |||
eb0fa43868 | |||
0ad9951c41 | |||
8c9117181d | |||
c4b48d3c0f | |||
10d765482d | |||
39b643dc1a | |||
711f485643 | |||
9c5ee91b2a | |||
27edd2aeb4 | |||
e5017cd6d6 | |||
6a7796e871 | |||
47b9339546 | |||
5d5146eee3 | |||
2aaa423842 | |||
ad2d788016 | |||
36ce76c632 | |||
f1fc2107a3 | |||
13cdc02173 | |||
502640c3f9 | |||
3d5f1c8640 | |||
1cab2f9cad | |||
1e50f1be70 | |||
ad87ba927a | |||
decf7f794b | |||
d00d652998 | |||
3b279a84be | |||
5e4a8223c6 | |||
e51de388a2 | |||
cc253b73d3 | |||
7d6fb905d9 | |||
418d111f8c | |||
be8921fbba | |||
d4e7a1152d | |||
be22bb6f3d | |||
169313b9f8 | |||
0b018d8baf | |||
c31246800c | |||
4134312b35 | |||
da554f932e | |||
aac622e0cd | |||
1726e93ef1 | |||
ee04c0cd04 | |||
c36f0aa300 | |||
5234dc7451 | |||
3b7c20a6b5 | |||
f9e714813a | |||
2518230d3e | |||
a332b84578 | |||
1405f0c7ba | |||
84d57342b6 | |||
57b46d769e | |||
f48b6a03ba | |||
2a69ab4899 | |||
8d7da92fd7 | |||
e952eee698 | |||
66bca9b8bd | |||
99028fda44 | |||
1244948885 | |||
a73f6491c8 | |||
001e50c92c | |||
96ebcaa3ad | |||
5db1870bb9 | |||
2ce26b9b5d | |||
a388252ac4 | |||
9a9f48dff7 | |||
67f3fb0844 | |||
43b752c325 | |||
cfd302db9b | |||
fb610ae684 | |||
2f652e6cdf | |||
e6a226efba | |||
a2e6fa7e03 | |||
9f1c4ecaf2 | |||
ef283548f7 | |||
f4db5e6de1 | |||
099aaee536 | |||
35fe398c7c | |||
bb6d43047e | |||
bc546f76a1 | |||
80608ba5af | |||
e184c9c510 | |||
d7e34b4210 | |||
ef6e0e7132 | |||
1ad3aca682 | |||
8d0afa9b42 | |||
fa7e254a7f | |||
e23cacda35 | |||
2e1b8bc2b6 | |||
e47433b3c1 | |||
23194d83e8 | |||
61aedb5ffe | |||
d3bd171123 | |||
89e4050af4 | |||
78a47f87ce | |||
6a113d9aed | |||
2e4fe48c37 | |||
8eb0a1d906 | |||
fea3e476aa | |||
61a3431613 | |||
9bedac9623 | |||
c42ff4f4fd | |||
d5ab28511c | |||
e61eb5e09d | |||
0899ba5b42 | |||
145ac73317 | |||
d0d138bc55 | |||
43227236ec | |||
8616300ae2 | |||
edbaadd91f | |||
9360d34fa1 | |||
1b67b04656 | |||
bd51f78e39 | |||
65ecb4f134 | |||
143844fa43 | |||
219cfbe7f6 | |||
9b44a7d926 | |||
a3ae45a38c | |||
0307428d65 | |||
471997adf6 | |||
b1ded114b9 | |||
f4e4088c99 | |||
0efd540dbc | |||
6144754014 | |||
69311446ba | |||
da63274d9f | |||
c216119d64 | |||
5546acb463 | |||
c0ec81836f | |||
b65e56babe | |||
49996cd597 | |||
ecb37e276a | |||
a5354b3ed2 | |||
f9df8b4ad7 | |||
ec152c8748 | |||
7977e5027c | |||
3f5d902d2a | |||
27d7638b94 | |||
176173989a | |||
23b8ee672d | |||
3939152069 | |||
cd87bfbf37 | |||
b3613e3ace | |||
d346ec695e | |||
c242c98031 | |||
f1d53d150c | |||
92da847cf5 | |||
3958b96bf5 | |||
8bf8f45822 | |||
6f5c0931c1 | |||
4e33a7ea85 | |||
dc48ba0c75 | |||
4778b42660 | |||
c70ac4b8ff | |||
cf89202855 | |||
f075693da7 | |||
f708bd4904 | |||
0002b7f0d1 | |||
11aafd9886 |
@@ -368,7 +368,7 @@ if __name__ == "__main__":
     # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...",
     # we want to turn it into "8xGPUTYPE"
     df["GPU"] = df["GPU"].apply(
-        lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}"
+        lambda x: f"{len(x.splitlines())}x{x.splitlines()[0]}"
     )

     # get markdown tables
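For context, the change above swaps x.split('\n') for x.splitlines() inside the f-string; before Python 3.12, f-string expressions cannot contain backslashes, so the splitlines() form keeps the expression valid on older interpreters while producing the same label. A minimal standalone sketch of the collapsing logic (the sample GPU string is illustrative only):

# Minimal sketch of the "GPUTYPE\nGPUTYPE\n..." -> "8xGPUTYPE" collapsing shown above.
gpu_cell = "H100\nH100\nH100\nH100\nH100\nH100\nH100\nH100"   # illustrative value
parts = gpu_cell.splitlines()                                  # ["H100", "H100", ...]
label = f"{len(parts)}x{parts[0]}"                             # "8xH100"
print(label)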
@@ -181,18 +181,14 @@ launch_vllm_server() {
   if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then
     echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience."
     model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model')
-    server_command="python3 \
-      -m vllm.entrypoints.openai.api_server \
+    server_command="vllm serve $model \
       -tp $tp \
-      --model $model \
       --port $port \
       $server_args"
   else
     echo "Key 'fp8' does not exist in common params."
-    server_command="python3 \
-      -m vllm.entrypoints.openai.api_server \
+    server_command="vllm serve $model \
       -tp $tp \
-      --model $model \
       --port $port \
       $server_args"
   fi
@@ -365,8 +365,7 @@ run_serving_tests() {
       continue
     fi

-    server_command="$server_envs python3 \
-      -m vllm.entrypoints.openai.api_server \
+    server_command="$server_envs vllm serve \
       $server_args"

     # run the server
@@ -455,11 +454,6 @@ main() {
   fi
   check_hf_token

-  # Set to v1 to run v1 benchmark
-  if [[ "${ENGINE_VERSION:-v0}" == "v1" ]]; then
-    export VLLM_USE_V1=1
-  fi
-
   # dependencies
   (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
   (which jq) || (apt-get update && apt-get -y install jq)
@@ -1,46 +0,0 @@
-# This local pyproject file is part of the migration from yapf to ruff format.
-# It uses the same core rules as the main pyproject.toml file, but with the
-# following differences:
-# - ruff line length is overridden to 88
-# - deprecated typing ignores (UP006, UP035) have been removed
-
-[tool.ruff]
-line-length = 88
-
-[tool.ruff.lint.per-file-ignores]
-"vllm/third_party/**" = ["ALL"]
-"vllm/version.py" = ["F401"]
-"vllm/_version.py" = ["ALL"]
-
-[tool.ruff.lint]
-select = [
-    # pycodestyle
-    "E",
-    # Pyflakes
-    "F",
-    # pyupgrade
-    "UP",
-    # flake8-bugbear
-    "B",
-    # flake8-simplify
-    "SIM",
-    # isort
-    "I",
-    # flake8-logging-format
-    "G",
-]
-ignore = [
-    # star imports
-    "F405", "F403",
-    # lambda expression assignment
-    "E731",
-    # Loop control variable not used within loop body
-    "B007",
-    # f-string format
-    "UP032",
-    # Can remove once 3.10+ is the minimum Python version
-    "UP007",
-]
-
-[tool.ruff.format]
-docstring-code-format = true
@ -48,7 +48,7 @@ steps:
|
|||||||
agents:
|
agents:
|
||||||
queue: cpu_queue_postmerge
|
queue: cpu_queue_postmerge
|
||||||
commands:
|
commands:
|
||||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||||
- "mkdir artifacts"
|
- "mkdir artifacts"
|
||||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||||
@ -76,7 +76,7 @@ steps:
|
|||||||
queue: arm64_cpu_queue_postmerge
|
queue: arm64_cpu_queue_postmerge
|
||||||
commands:
|
commands:
|
||||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg torch_cuda_arch_list='8.7 9.0 10.0+PTX 12.0' --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m) --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
|
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-$(uname -m)"
|
||||||
|
|
||||||
# Add job to create multi-arch manifest
|
# Add job to create multi-arch manifest
|
||||||
@ -150,11 +150,16 @@ steps:
|
|||||||
queue: cpu_queue_postmerge
|
queue: cpu_queue_postmerge
|
||||||
commands:
|
commands:
|
||||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||||
- "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
- "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64"
|
||||||
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly"
|
- "docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64"
|
||||||
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
|
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-x86_64 vllm/vllm-openai:nightly-x86_64"
|
||||||
- "docker push vllm/vllm-openai:nightly"
|
- "docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT-aarch64 vllm/vllm-openai:nightly-aarch64"
|
||||||
- "docker push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
|
- "docker push vllm/vllm-openai:nightly-x86_64"
|
||||||
|
- "docker push vllm/vllm-openai:nightly-aarch64"
|
||||||
|
- "docker manifest create vllm/vllm-openai:nightly vllm/vllm-openai:nightly-x86_64 vllm/vllm-openai:nightly-aarch64 --amend"
|
||||||
|
- "docker manifest create vllm/vllm-openai:nightly-$BUILDKITE_COMMIT vllm/vllm-openai:nightly-x86_64 vllm/vllm-openai:nightly-aarch64 --amend"
|
||||||
|
- "docker manifest push vllm/vllm-openai:nightly"
|
||||||
|
- "docker manifest push vllm/vllm-openai:nightly-$BUILDKITE_COMMIT"
|
||||||
# Clean up old nightly builds (keep only last 14)
|
# Clean up old nightly builds (keep only last 14)
|
||||||
- "bash .buildkite/scripts/cleanup-nightly-builds.sh"
|
- "bash .buildkite/scripts/cleanup-nightly-builds.sh"
|
||||||
plugins:
|
plugins:
|
||||||
@ -163,3 +168,4 @@ steps:
|
|||||||
password-env: DOCKERHUB_TOKEN
|
password-env: DOCKERHUB_TOKEN
|
||||||
env:
|
env:
|
||||||
DOCKER_BUILDKIT: "1"
|
DOCKER_BUILDKIT: "1"
|
||||||
|
DOCKERHUB_USERNAME: "vllmbot"
|
||||||
|
@ -8,20 +8,41 @@ set -ex
|
|||||||
# DockerHub API endpoint for vllm/vllm-openai repository
|
# DockerHub API endpoint for vllm/vllm-openai repository
|
||||||
REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"
|
REPO_API_URL="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"
|
||||||
|
|
||||||
# Get DockerHub token from environment
|
# Get DockerHub credentials from environment
|
||||||
if [ -z "$DOCKERHUB_TOKEN" ]; then
|
if [ -z "$DOCKERHUB_TOKEN" ]; then
|
||||||
echo "Error: DOCKERHUB_TOKEN environment variable is not set"
|
echo "Error: DOCKERHUB_TOKEN environment variable is not set"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ -z "$DOCKERHUB_USERNAME" ]; then
|
||||||
|
echo "Error: DOCKERHUB_USERNAME environment variable is not set"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Get DockerHub bearer token
|
||||||
|
echo "Getting DockerHub bearer token..."
|
||||||
|
set +x
|
||||||
|
BEARER_TOKEN=$(curl -s -X POST \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d "{\"username\": \"$DOCKERHUB_USERNAME\", \"password\": \"$DOCKERHUB_TOKEN\"}" \
|
||||||
|
"https://hub.docker.com/v2/users/login" | jq -r '.token')
|
||||||
|
set -x
|
||||||
|
|
||||||
|
if [ -z "$BEARER_TOKEN" ] || [ "$BEARER_TOKEN" = "null" ]; then
|
||||||
|
echo "Error: Failed to get DockerHub bearer token"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
# Function to get all tags from DockerHub
|
# Function to get all tags from DockerHub
|
||||||
get_all_tags() {
|
get_all_tags() {
|
||||||
local page=1
|
local page=1
|
||||||
local all_tags=""
|
local all_tags=""
|
||||||
|
|
||||||
while true; do
|
while true; do
|
||||||
local response=$(curl -s -H "Authorization: Bearer $DOCKERHUB_TOKEN" \
|
set +x
|
||||||
|
local response=$(curl -s -H "Authorization: Bearer $BEARER_TOKEN" \
|
||||||
"$REPO_API_URL?page=$page&page_size=100")
|
"$REPO_API_URL?page=$page&page_size=100")
|
||||||
|
set -x
|
||||||
|
|
||||||
# Get both last_updated timestamp and tag name, separated by |
|
# Get both last_updated timestamp and tag name, separated by |
|
||||||
local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')
|
local tags=$(echo "$response" | jq -r '.results[] | select(.name | startswith("nightly-")) | "\(.last_updated)|\(.name)"')
|
||||||
@ -43,7 +64,9 @@ delete_tag() {
|
|||||||
echo "Deleting tag: $tag_name"
|
echo "Deleting tag: $tag_name"
|
||||||
|
|
||||||
local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
|
local delete_url="https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags/$tag_name"
|
||||||
local response=$(curl -s -X DELETE -H "Authorization: Bearer $DOCKERHUB_TOKEN" "$delete_url")
|
set +x
|
||||||
|
local response=$(curl -s -X DELETE -H "Authorization: Bearer $BEARER_TOKEN" "$delete_url")
|
||||||
|
set -x
|
||||||
|
|
||||||
if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
|
if echo "$response" | jq -e '.detail' > /dev/null 2>&1; then
|
||||||
echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
|
echo "Warning: Failed to delete tag $tag_name: $(echo "$response" | jq -r '.detail')"
|
||||||
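The hunks above change .buildkite/scripts/cleanup-nightly-builds.sh so that it first exchanges DOCKERHUB_USERNAME/DOCKERHUB_TOKEN for a bearer token at hub.docker.com/v2/users/login and then uses that token for the tag listing and deletion calls, wrapping each curl in set +x/set -x so credentials never land in the CI log. A rough Python sketch of the same list-then-prune flow, for illustration only (it assumes the requests package; the pagination handling and the keep-the-last-14 policy follow the script's comments rather than a verified API contract):

# Hedged sketch (not the CI script itself): prune old "nightly-" tags on DockerHub.
import os
import requests

LOGIN_URL = "https://hub.docker.com/v2/users/login"
TAGS_URL = "https://hub.docker.com/v2/repositories/vllm/vllm-openai/tags"

# Exchange username/password for a bearer token, as the updated script does.
resp = requests.post(LOGIN_URL, json={
    "username": os.environ["DOCKERHUB_USERNAME"],
    "password": os.environ["DOCKERHUB_TOKEN"],
})
resp.raise_for_status()
headers = {"Authorization": f"Bearer {resp.json()['token']}"}

# Collect (last_updated, name) for every tag starting with "nightly-".
tags, page = [], 1
while True:
    data = requests.get(TAGS_URL, headers=headers,
                        params={"page": page, "page_size": 100}).json()
    tags += [(t["last_updated"], t["name"])
             for t in data.get("results", [])
             if t["name"].startswith("nightly-")]
    if not data.get("next"):   # assumption: DockerHub paginates with a "next" link
        break
    page += 1

# Keep only the 14 most recent nightly tags; delete the rest.
for _, name in sorted(tags, reverse=True)[14:]:
    requests.delete(f"{TAGS_URL}/{name}", headers=headers)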
|
.buildkite/scripts/hardware_ci/run-npu-test.sh (new file, 191 lines)
@ -0,0 +1,191 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This script builds the Ascend NPU docker image and runs the offline inference inside the container.
|
||||||
|
# It serves as a sanity check for compilation and basic model usage.
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
# Base ubuntu image with basic ascend development libraries and python installed
|
||||||
|
VLLM_ASCEND_REPO="https://github.com/vllm-project/vllm-ascend.git"
|
||||||
|
CONFIG_FILE_REMOTE_PATH="tests/e2e/vllm_interface/vllm_test.cfg"
|
||||||
|
TEST_RUN_CONFIG_FILE="vllm_test.cfg"
|
||||||
|
VLLM_ASCEND_TMP_DIR=
|
||||||
|
# Get the test run configuration file from the vllm-ascend repository
|
||||||
|
fetch_vllm_test_cfg() {
|
||||||
|
VLLM_ASCEND_TMP_DIR=$(mktemp -d)
|
||||||
|
# Ensure that the temporary directory is cleaned up when an exception occurs during configuration file retrieval
|
||||||
|
cleanup() {
|
||||||
|
rm -rf "${VLLM_ASCEND_TMP_DIR}"
|
||||||
|
}
|
||||||
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
GIT_TRACE=1 git clone -v --depth 1 "${VLLM_ASCEND_REPO}" "${VLLM_ASCEND_TMP_DIR}"
|
||||||
|
if [ ! -f "${VLLM_ASCEND_TMP_DIR}/${CONFIG_FILE_REMOTE_PATH}" ]; then
|
||||||
|
echo "Error: file '${CONFIG_FILE_REMOTE_PATH}' does not exist in the warehouse" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If the file already exists locally, just overwrite it
|
||||||
|
cp "${VLLM_ASCEND_TMP_DIR}/${CONFIG_FILE_REMOTE_PATH}" "${TEST_RUN_CONFIG_FILE}"
|
||||||
|
echo "Copied ${CONFIG_FILE_REMOTE_PATH} to ${TEST_RUN_CONFIG_FILE}"
|
||||||
|
|
||||||
|
# The cleanup trap registered above is overwritten later; by the time execution reaches this point,
|
||||||
|
# its error-cleanup work is already done, so delete the temporary directory manually here.
|
||||||
|
rm -rf "${VLLM_ASCEND_TMP_DIR}"
|
||||||
|
trap - EXIT
|
||||||
|
}
|
||||||
|
|
||||||
|
# Downloads test run configuration file from a remote URL.
|
||||||
|
# Loads the configuration into the current script environment.
|
||||||
|
get_config() {
|
||||||
|
if [ ! -f "${TEST_RUN_CONFIG_FILE}" ]; then
|
||||||
|
echo "Error: file '${TEST_RUN_CONFIG_FILE}' does not exist in the warehouse" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
source "${TEST_RUN_CONFIG_FILE}"
|
||||||
|
echo "Base docker image name that get from configuration: ${BASE_IMAGE_NAME}"
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# get test running configuration.
|
||||||
|
fetch_vllm_test_cfg
|
||||||
|
get_config
|
||||||
|
# Check if the function call was successful. If not, exit the script.
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
image_name="npu/vllm-ci:${BUILDKITE_COMMIT}_${EPOCHSECONDS}"
|
||||||
|
container_name="npu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
|
||||||
|
|
||||||
|
# BUILDKITE_AGENT_NAME format is {hostname}-{agent_idx}-{npu_card_num}cards
|
||||||
|
agent_idx=$(echo "${BUILDKITE_AGENT_NAME}" | awk -F'-' '{print $(NF-1)}')
|
||||||
|
echo "agent_idx: ${agent_idx}"
|
||||||
|
builder_name="cachebuilder${agent_idx}"
|
||||||
|
builder_cache_dir="/mnt/docker-cache${agent_idx}"
|
||||||
|
mkdir -p ${builder_cache_dir}
|
||||||
|
|
||||||
|
# Try building the docker image
|
||||||
|
cat <<EOF | DOCKER_BUILDKIT=1 docker build \
|
||||||
|
--add-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_HOST} \
|
||||||
|
--builder ${builder_name} --cache-from type=local,src=${builder_cache_dir} \
|
||||||
|
--cache-to type=local,dest=${builder_cache_dir},mode=max \
|
||||||
|
--progress=plain --load -t ${image_name} -f - .
|
||||||
|
FROM ${BASE_IMAGE_NAME}
|
||||||
|
|
||||||
|
# Define environments
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
RUN pip config set global.index-url http://cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_PORT}/pypi/simple && \
|
||||||
|
pip config set global.trusted-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local && \
|
||||||
|
apt-get update -y && \
|
||||||
|
apt-get install -y python3-pip git vim wget net-tools gcc g++ cmake libnuma-dev && \
|
||||||
|
rm -rf /var/cache/apt/* && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install for pytest to make the docker build cache layer always valid
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip install pytest>=6.0 modelscope
|
||||||
|
|
||||||
|
WORKDIR /workspace/vllm
|
||||||
|
|
||||||
|
# Install vLLM dependencies in advance. Effect: As long as common.txt remains unchanged, the docker cache layer will be valid.
|
||||||
|
COPY requirements/common.txt /workspace/vllm/requirements/common.txt
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip install -r requirements/common.txt
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# Install vLLM
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
VLLM_TARGET_DEVICE="empty" python3 -m pip install -v -e /workspace/vllm/ --extra-index https://download.pytorch.org/whl/cpu/ && \
|
||||||
|
python3 -m pip uninstall -y triton
|
||||||
|
|
||||||
|
# Install vllm-ascend
|
||||||
|
WORKDIR /workspace
|
||||||
|
ARG VLLM_ASCEND_REPO=https://github.com/vllm-project/vllm-ascend.git
|
||||||
|
ARG VLLM_ASCEND_TAG=main
|
||||||
|
RUN git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf "https://github.com/" && \
|
||||||
|
git clone --depth 1 \$VLLM_ASCEND_REPO --branch \$VLLM_ASCEND_TAG /workspace/vllm-ascend
|
||||||
|
|
||||||
|
# Install vllm dependencies in advance. Effect: As long as common.txt remains unchanged, the docker cache layer will be valid.
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
pip install -r /workspace/vllm-ascend/requirements.txt
|
||||||
|
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||||
|
export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi && \
|
||||||
|
source /usr/local/Ascend/ascend-toolkit/set_env.sh && \
|
||||||
|
source /usr/local/Ascend/nnal/atb/set_env.sh && \
|
||||||
|
export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/`uname -i`-linux/devlib && \
|
||||||
|
python3 -m pip install -v -e /workspace/vllm-ascend/ --extra-index https://download.pytorch.org/whl/cpu/
|
||||||
|
|
||||||
|
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
ENV VLLM_USE_MODELSCOPE=True
|
||||||
|
|
||||||
|
WORKDIR /workspace/vllm-ascend
|
||||||
|
|
||||||
|
CMD ["/bin/bash"]
|
||||||
|
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Setup cleanup
|
||||||
|
remove_docker_container() {
|
||||||
|
docker rm -f "${container_name}" || true;
|
||||||
|
docker image rm -f "${image_name}" || true;
|
||||||
|
docker system prune -f || true;
|
||||||
|
}
|
||||||
|
trap remove_docker_container EXIT
|
||||||
|
|
||||||
|
# Generate corresponding --device args based on BUILDKITE_AGENT_NAME
|
||||||
|
# Ascend NPU BUILDKITE_AGENT_NAME format is {hostname}-{agent_idx}-{npu_card_num}cards, and agent_idx starts from 1.
|
||||||
|
# e.g. atlas-a2-001-1-2cards means this is the 1-th agent on atlas-a2-001 host, and it has 2 NPU cards.
|
||||||
|
# returns --device /dev/davinci0 --device /dev/davinci1
|
||||||
|
parse_and_gen_devices() {
|
||||||
|
local input="$1"
|
||||||
|
local index cards_num
|
||||||
|
if [[ "$input" =~ ([0-9]+)-([0-9]+)cards$ ]]; then
|
||||||
|
index="${BASH_REMATCH[1]}"
|
||||||
|
cards_num="${BASH_REMATCH[2]}"
|
||||||
|
else
|
||||||
|
echo "parse error" >&2
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
local devices=""
|
||||||
|
local i=0
|
||||||
|
while (( i < cards_num )); do
|
||||||
|
local dev_idx=$(((index - 1)*cards_num + i ))
|
||||||
|
devices="$devices --device /dev/davinci${dev_idx}"
|
||||||
|
((i++))
|
||||||
|
done
|
||||||
|
|
||||||
|
# trim leading space
|
||||||
|
devices="${devices#"${devices%%[![:space:]]*}"}"
|
||||||
|
# Output devices: assigned to the caller variable
|
||||||
|
printf '%s' "$devices"
|
||||||
|
}
|
||||||
|
|
||||||
|
devices=$(parse_and_gen_devices "${BUILDKITE_AGENT_NAME}") || exit 1
|
||||||
|
|
||||||
|
# Run the image and execute the Out-Of-Tree (OOT) platform interface test case on Ascend NPU hardware.
|
||||||
|
# This test checks whether the OOT platform interface is functioning properly in conjunction with
|
||||||
|
# the hardware plugin vllm-ascend.
|
||||||
|
model_cache_dir=/mnt/modelscope${agent_idx}
|
||||||
|
mkdir -p ${model_cache_dir}
|
||||||
|
docker run \
|
||||||
|
${devices} \
|
||||||
|
--device /dev/davinci_manager \
|
||||||
|
--device /dev/devmm_svm \
|
||||||
|
--device /dev/hisi_hdc \
|
||||||
|
-v /usr/local/dcmi:/usr/local/dcmi \
|
||||||
|
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
|
||||||
|
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
|
||||||
|
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
|
||||||
|
-v /etc/ascend_install.info:/etc/ascend_install.info \
|
||||||
|
-v ${model_cache_dir}:/root/.cache/modelscope \
|
||||||
|
--entrypoint="" \
|
||||||
|
--name "${container_name}" \
|
||||||
|
"${image_name}" \
|
||||||
|
bash -c '
|
||||||
|
set -e
|
||||||
|
pytest -v -s tests/e2e/vllm_interface/
|
||||||
|
'
|
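The parse_and_gen_devices helper in the new script above derives the NPU --device flags from the agent name format {hostname}-{agent_idx}-{npu_card_num}cards. A small Python re-expression of that mapping, for illustration only (it is not part of the CI script):

# Agent "atlas-a2-001-1-2cards" is the 1st agent with 2 cards -> /dev/davinci0, /dev/davinci1.
import re

def npu_device_args(agent_name: str) -> list[str]:
    match = re.search(r"(\d+)-(\d+)cards$", agent_name)
    if match is None:
        raise ValueError(f"cannot parse agent name: {agent_name}")
    index, cards = int(match.group(1)), int(match.group(2))
    # Agent N (1-based) gets the contiguous block of devices [(N-1)*cards, N*cards).
    return [f"--device /dev/davinci{(index - 1) * cards + i}" for i in range(cards)]

print(npu_device_args("atlas-a2-001-1-2cards"))
# ['--device /dev/davinci0', '--device /dev/davinci1']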
@ -64,10 +64,9 @@ python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git
|
|||||||
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
|
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
|
||||||
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
|
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
|
||||||
echo "--- Python dependencies installed ---"
|
echo "--- Python dependencies installed ---"
|
||||||
export VLLM_USE_V1=1
|
|
||||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||||
export VLLM_XLA_CACHE_PATH=
|
export VLLM_XLA_CACHE_PATH=
|
||||||
echo "Using VLLM V1"
|
|
||||||
|
|
||||||
echo "--- Hardware Information ---"
|
echo "--- Hardware Information ---"
|
||||||
# tpu-info
|
# tpu-info
|
||||||
|
@ -64,10 +64,9 @@ python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git
|
|||||||
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
|
&& python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \
|
||||||
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
|
&& python3 -m pip install --progress-bar off hf-transfer tblib==3.1.0
|
||||||
echo "--- Python dependencies installed ---"
|
echo "--- Python dependencies installed ---"
|
||||||
export VLLM_USE_V1=1
|
|
||||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||||
export VLLM_XLA_CACHE_PATH=
|
export VLLM_XLA_CACHE_PATH=
|
||||||
echo "Using VLLM V1"
|
|
||||||
|
|
||||||
echo "--- Hardware Information ---"
|
echo "--- Hardware Information ---"
|
||||||
# tpu-info
|
# tpu-info
|
||||||
|
@ -42,9 +42,8 @@ docker run \
|
|||||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
||||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
||||||
pytest -v -s v1/structured_output
|
pytest -v -s v1/structured_output
|
||||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
|
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py
|
||||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
|
||||||
|
pytest -v -s v1/test_metrics
|
||||||
pytest -v -s v1/test_serial_utils.py
|
pytest -v -s v1/test_serial_utils.py
|
||||||
pytest -v -s v1/test_utils.py
|
|
||||||
pytest -v -s v1/test_metrics_reader.py
|
|
||||||
'
|
'
|
||||||
|
@@ -18,7 +18,7 @@ vllm bench throughput --input-len 256 --output-len 256 --output-json throughput_
 bench_throughput_exit_code=$?

 # run server-based benchmarks and upload the result to buildkite
-python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
+vllm serve meta-llama/Llama-2-7b-chat-hf &
 server_pid=$!
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

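Once the benchmark script has started the server with vllm serve, it exposes the same OpenAI-compatible endpoint that the old python3 -m vllm.entrypoints.openai.api_server invocation did; a minimal client-side sanity check might look like this sketch (assuming the openai Python package and vLLM's default port 8000):

# Hedged example: query the server started above with `vllm serve ... &`.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    prompt="Hello, my name is",
    max_tokens=16,
)
print(completion.choices[0].text)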
@ -9,6 +9,6 @@ MAX_NUM_BATCHED_TOKENS=1024
|
|||||||
TENSOR_PARALLEL_SIZE=1
|
TENSOR_PARALLEL_SIZE=1
|
||||||
MAX_MODEL_LEN=2048
|
MAX_MODEL_LEN=2048
|
||||||
DOWNLOAD_DIR=/mnt/disks/persist
|
DOWNLOAD_DIR=/mnt/disks/persist
|
||||||
EXPECTED_THROUGHPUT=10.0
|
EXPECTED_THROUGHPUT=8.7
|
||||||
INPUT_LEN=1800
|
INPUT_LEN=1800
|
||||||
OUTPUT_LEN=128
|
OUTPUT_LEN=128
|
||||||
|
@ -42,7 +42,7 @@ echo "lanching vllm..."
|
|||||||
echo "logging to $VLLM_LOG"
|
echo "logging to $VLLM_LOG"
|
||||||
echo
|
echo
|
||||||
|
|
||||||
VLLM_USE_V1=1 vllm serve $MODEL \
|
vllm serve $MODEL \
|
||||||
--seed 42 \
|
--seed 42 \
|
||||||
--max-num-seqs $MAX_NUM_SEQS \
|
--max-num-seqs $MAX_NUM_SEQS \
|
||||||
--max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \
|
--max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \
|
||||||
|
@ -50,19 +50,28 @@ steps:
|
|||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
|
- tests/multimodal
|
||||||
|
- tests/utils_
|
||||||
|
commands:
|
||||||
|
- pytest -v -s -m 'not cpu_test' multimodal
|
||||||
|
- pytest -v -s utils_
|
||||||
|
|
||||||
|
- label: Async Engine, Inputs, Utils, Worker Test (CPU) # 4 mins
|
||||||
|
timeout_in_minutes: 10
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
- tests/test_inputs.py
|
- tests/test_inputs.py
|
||||||
- tests/test_outputs.py
|
- tests/test_outputs.py
|
||||||
- tests/multimodal
|
- tests/multimodal
|
||||||
- tests/utils_
|
|
||||||
- tests/standalone_tests/lazy_imports.py
|
- tests/standalone_tests/lazy_imports.py
|
||||||
- tests/transformers_utils
|
- tests/transformers_utils
|
||||||
|
no_gpu: true
|
||||||
commands:
|
commands:
|
||||||
- python3 standalone_tests/lazy_imports.py
|
- python3 standalone_tests/lazy_imports.py
|
||||||
- pytest -v -s test_inputs.py
|
- pytest -v -s test_inputs.py
|
||||||
- pytest -v -s test_outputs.py
|
- pytest -v -s test_outputs.py
|
||||||
- pytest -v -s multimodal
|
- pytest -v -s -m 'cpu_test' multimodal
|
||||||
- pytest -v -s utils_ # Utils
|
- pytest -v -s transformers_utils
|
||||||
- pytest -v -s transformers_utils # transformers_utils
|
|
||||||
|
|
||||||
- label: Python-only Installation Test # 10min
|
- label: Python-only Installation Test # 10min
|
||||||
timeout_in_minutes: 20
|
timeout_in_minutes: 20
|
||||||
@ -159,10 +168,7 @@ steps:
|
|||||||
- examples/offline_inference/rlhf.py
|
- examples/offline_inference/rlhf.py
|
||||||
- examples/offline_inference/rlhf_colocate.py
|
- examples/offline_inference/rlhf_colocate.py
|
||||||
- tests/examples/offline_inference/data_parallel.py
|
- tests/examples/offline_inference/data_parallel.py
|
||||||
- tests/v1/test_async_llm_dp.py
|
- tests/v1/distributed
|
||||||
- tests/v1/test_external_lb_dp.py
|
|
||||||
- tests/v1/test_internal_lb_dp.py
|
|
||||||
- tests/v1/test_hybrid_lb_dp.py
|
|
||||||
- tests/v1/engine/test_engine_core_client.py
|
- tests/v1/engine/test_engine_core_client.py
|
||||||
- tests/distributed/test_symm_mem_allreduce.py
|
- tests/distributed/test_symm_mem_allreduce.py
|
||||||
commands:
|
commands:
|
||||||
@ -180,10 +186,10 @@ steps:
|
|||||||
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
- TP_SIZE=2 DP_SIZE=2 ENABLE_EP=1 torchrun --nproc-per-node=4 distributed/test_torchrun_example_moe.py
|
||||||
# test with internal dp
|
# test with internal dp
|
||||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_internal_lb_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/distributed/test_hybrid_lb_dp.py
|
||||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||||
- pytest -v -s distributed/test_utils.py
|
- pytest -v -s distributed/test_utils.py
|
||||||
- pytest -v -s compile/test_basic_correctness.py
|
- pytest -v -s compile/test_basic_correctness.py
|
||||||
@ -290,26 +296,35 @@ steps:
|
|||||||
- tests/v1
|
- tests/v1
|
||||||
commands:
|
commands:
|
||||||
# split the test to avoid interference
|
# split the test to avoid interference
|
||||||
- pytest -v -s v1/core
|
- pytest -v -s -m 'not cpu_test' v1/core
|
||||||
- pytest -v -s v1/executor
|
- pytest -v -s v1/executor
|
||||||
- pytest -v -s v1/kv_offload
|
- pytest -v -s v1/kv_offload
|
||||||
- pytest -v -s v1/sample
|
- pytest -v -s v1/sample
|
||||||
- pytest -v -s v1/logits_processors
|
- pytest -v -s v1/logits_processors
|
||||||
- pytest -v -s v1/worker
|
- pytest -v -s v1/worker
|
||||||
- pytest -v -s v1/structured_output
|
|
||||||
- pytest -v -s v1/spec_decode
|
- pytest -v -s v1/spec_decode
|
||||||
- pytest -v -s v1/kv_connector/unit
|
- pytest -v -s -m 'not cpu_test' v1/kv_connector/unit
|
||||||
- pytest -v -s v1/metrics
|
- pytest -v -s -m 'not cpu_test' v1/metrics
|
||||||
- pytest -v -s v1/test_kv_sharing.py
|
|
||||||
- pytest -v -s v1/test_metrics_reader.py
|
|
||||||
- pytest -v -s v1/test_oracle.py
|
- pytest -v -s v1/test_oracle.py
|
||||||
- pytest -v -s v1/test_request.py
|
- pytest -v -s v1/test_request.py
|
||||||
- pytest -v -s v1/test_serial_utils.py
|
|
||||||
- pytest -v -s v1/test_utils.py
|
|
||||||
# Integration test for streaming correctness (requires special branch).
|
# Integration test for streaming correctness (requires special branch).
|
||||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||||
|
|
||||||
|
- label: V1 Test others (CPU) # 5 mins
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/v1
|
||||||
|
no_gpu: true
|
||||||
|
commands:
|
||||||
|
# split the test to avoid interference
|
||||||
|
- pytest -v -s -m 'cpu_test' v1/core
|
||||||
|
- pytest -v -s v1/structured_output
|
||||||
|
- pytest -v -s v1/test_serial_utils.py
|
||||||
|
- pytest -v -s -m 'cpu_test' v1/kv_connector/unit
|
||||||
|
- pytest -v -s -m 'cpu_test' v1/metrics
|
||||||
|
|
||||||
|
|
||||||
- label: Examples Test # 30min
|
- label: Examples Test # 30min
|
||||||
timeout_in_minutes: 45
|
timeout_in_minutes: 45
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
@ -383,9 +398,8 @@ steps:
|
|||||||
- pytest -v -s compile/test_pass_manager.py
|
- pytest -v -s compile/test_pass_manager.py
|
||||||
- pytest -v -s compile/test_fusion.py
|
- pytest -v -s compile/test_fusion.py
|
||||||
- pytest -v -s compile/test_fusion_attn.py
|
- pytest -v -s compile/test_fusion_attn.py
|
||||||
|
- pytest -v -s compile/test_functionalization.py
|
||||||
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
- pytest -v -s compile/test_silu_mul_quant_fusion.py
|
||||||
- pytest -v -s compile/test_sequence_parallelism.py
|
|
||||||
- pytest -v -s compile/test_async_tp.py
|
|
||||||
- pytest -v -s compile/test_fusion_all_reduce.py
|
- pytest -v -s compile/test_fusion_all_reduce.py
|
||||||
- pytest -v -s compile/test_decorator.py
|
- pytest -v -s compile/test_decorator.py
|
||||||
- pytest -v -s compile/test_noop_elimination.py
|
- pytest -v -s compile/test_noop_elimination.py
|
||||||
@ -417,8 +431,9 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/
|
- csrc/
|
||||||
- tests/kernels/core
|
- tests/kernels/core
|
||||||
|
- tests/kernels/test_top_k_per_row.py
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/core
|
- pytest -v -s kernels/core kernels/test_top_k_per_row.py
|
||||||
|
|
||||||
- label: Kernels Attention Test %N # 23min
|
- label: Kernels Attention Test %N # 23min
|
||||||
timeout_in_minutes: 35
|
timeout_in_minutes: 35
|
||||||
@ -462,32 +477,22 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- csrc/mamba/
|
- csrc/mamba/
|
||||||
- tests/kernels/mamba
|
- tests/kernels/mamba
|
||||||
|
- vllm/model_executor/layers/mamba/ops
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s kernels/mamba
|
- pytest -v -s kernels/mamba
|
||||||
|
|
||||||
- label: Tensorizer Test # 14min
|
- label: Model Executor Test # 23min
|
||||||
timeout_in_minutes: 25
|
timeout_in_minutes: 35
|
||||||
mirror_hardwares: [amdexperimental]
|
|
||||||
source_file_dependencies:
|
|
||||||
- vllm/model_executor/model_loader
|
|
||||||
- tests/tensorizer_loader
|
|
||||||
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
commands:
|
|
||||||
- apt-get update && apt-get install -y curl libsodium23
|
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
|
||||||
- pytest -v -s tensorizer_loader
|
|
||||||
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
|
||||||
|
|
||||||
- label: Model Executor Test # 7min
|
|
||||||
timeout_in_minutes: 20
|
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/model_executor
|
- vllm/model_executor
|
||||||
- tests/model_executor
|
- tests/model_executor
|
||||||
|
- tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
commands:
|
commands:
|
||||||
- apt-get update && apt-get install -y curl libsodium23
|
- apt-get update && apt-get install -y curl libsodium23
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
- pytest -v -s model_executor
|
- pytest -v -s model_executor
|
||||||
|
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
|
|
||||||
- label: Benchmarks # 11min
|
- label: Benchmarks # 11min
|
||||||
timeout_in_minutes: 20
|
timeout_in_minutes: 20
|
||||||
@ -522,7 +527,7 @@ steps:
|
|||||||
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
# https://github.com/pytorch/ao/issues/2919, we'll have to skip new torchao tests for now
|
||||||
# we can only upgrade after this is resolved
|
# we can only upgrade after this is resolved
|
||||||
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
|
- pip install --pre torchao==0.13.0.dev20250814 --index-url https://download.pytorch.org/whl/nightly/cu128
|
||||||
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization
|
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization/
|
||||||
|
|
||||||
- label: LM Eval Small Models # 53min
|
- label: LM Eval Small Models # 53min
|
||||||
timeout_in_minutes: 75
|
timeout_in_minutes: 75
|
||||||
@ -550,10 +555,17 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/
|
- vllm/
|
||||||
- tests/tool_use
|
- tests/tool_use
|
||||||
- tests/mistral_tool_use
|
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s tool_use
|
- pytest -v -s -m 'not cpu_test' tool_use
|
||||||
- pytest -v -s mistral_tool_use
|
|
||||||
|
- label: OpenAI-Compatible Tool Use (CPU) # 5 mins
|
||||||
|
timeout_in_minutes: 10
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
|
- tests/tool_use
|
||||||
|
no_gpu: true
|
||||||
|
commands:
|
||||||
|
- pytest -v -s -m 'cpu_test' tool_use
|
||||||
|
|
||||||
##### models test #####
|
##### models test #####
|
||||||
|
|
||||||
@ -593,13 +605,19 @@ steps:
|
|||||||
- vllm/
|
- vllm/
|
||||||
- tests/models/test_transformers.py
|
- tests/models/test_transformers.py
|
||||||
- tests/models/test_registry.py
|
- tests/models/test_registry.py
|
||||||
|
commands:
|
||||||
|
- pytest -v -s models/test_transformers.py models/test_registry.py
|
||||||
|
|
||||||
|
- label: Basic Models Test (Other CPU) # 5min
|
||||||
|
timeout_in_minutes: 10
|
||||||
|
torch_nightly: true
|
||||||
|
source_file_dependencies:
|
||||||
|
- vllm/
|
||||||
- tests/models/test_utils.py
|
- tests/models/test_utils.py
|
||||||
- tests/models/test_vision.py
|
- tests/models/test_vision.py
|
||||||
|
no_gpu: true
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s models/test_transformers.py \
|
- pytest -v -s models/test_utils.py models/test_vision.py
|
||||||
models/test_registry.py \
|
|
||||||
models/test_utils.py \
|
|
||||||
models/test_vision.py
|
|
||||||
|
|
||||||
- label: Language Models Tests (Standard)
|
- label: Language Models Tests (Standard)
|
||||||
timeout_in_minutes: 25
|
timeout_in_minutes: 25
|
||||||
@ -769,6 +787,7 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pip install --upgrade git+https://github.com/huggingface/transformers
|
- pip install --upgrade git+https://github.com/huggingface/transformers
|
||||||
- pytest -v -s tests/models/test_initialization.py
|
- pytest -v -s tests/models/test_initialization.py
|
||||||
|
- pytest -v -s tests/models/test_transformers.py
|
||||||
- pytest -v -s tests/models/multimodal/processing/
|
- pytest -v -s tests/models/multimodal/processing/
|
||||||
- pytest -v -s tests/models/multimodal/test_mapping.py
|
- pytest -v -s tests/models/multimodal/test_mapping.py
|
||||||
- python3 examples/offline_inference/basic/chat.py
|
- python3 examples/offline_inference/basic/chat.py
|
||||||
@ -809,18 +828,20 @@ steps:
|
|||||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
|
- pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
|
||||||
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
- pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
|
||||||
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
- pytest -v -s tests/kernels/moe/test_nvfp4_moe.py
|
||||||
- pytest -v -s tests/kernels/moe/test_mxfp4_moe.py
|
- pytest -v -s tests/kernels/moe/test_ocp_mx_moe.py
|
||||||
# Fusion
|
# Fusion
|
||||||
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
- pytest -v -s tests/compile/test_fusion_all_reduce.py
|
||||||
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
- pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
|
||||||
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
- pytest -v -s tests/kernels/moe/test_flashinfer.py
|
||||||
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
- pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
|
||||||
|
- pytest -v -s tests/kernels/quantization/test_nvfp4_qutlass.py
|
||||||
|
- pytest -v -s tests/kernels/quantization/test_mxfp4_qutlass.py
|
||||||
|
|
||||||
- label: GPT-OSS Eval (Blackwell)
|
- label: Blackwell GPT-OSS Eval
|
||||||
timeout_in_minutes: 60
|
timeout_in_minutes: 60
|
||||||
working_dir: "/vllm-workspace/"
|
working_dir: "/vllm-workspace/"
|
||||||
gpu: b200
|
gpu: b200
|
||||||
optional: true # disable while debugging
|
optional: true # run on nightlies
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- tests/evals/gpt_oss
|
- tests/evals/gpt_oss
|
||||||
- vllm/model_executor/models/gpt_oss.py
|
- vllm/model_executor/models/gpt_oss.py
|
||||||
@ -828,7 +849,34 @@ steps:
|
|||||||
- vllm/v1/attention/backends/flashinfer.py
|
- vllm/v1/attention/backends/flashinfer.py
|
||||||
commands:
|
commands:
|
||||||
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
- uv pip install --system 'gpt-oss[eval]==0.0.5'
|
||||||
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58 --server-args '--tensor-parallel-size 2'
|
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
|
||||||
|
|
||||||
|
- label: Blackwell Quantized MoE Test
|
||||||
|
timeout_in_minutes: 60
|
||||||
|
working_dir: "/vllm-workspace/"
|
||||||
|
gpu: b200
|
||||||
|
source_file_dependencies:
|
||||||
|
- tests/quantization/test_blackwell_moe.py
|
||||||
|
- vllm/model_executor/models/deepseek_v2.py
|
||||||
|
- vllm/model_executor/models/gpt_oss.py
|
||||||
|
- vllm/model_executor/models/llama4.py
|
||||||
|
- vllm/model_executor/layers/fused_moe
|
||||||
|
- vllm/model_executor/layers/quantization/compressed_tensors
|
||||||
|
- vllm/model_executor/layers/quantization/modelopt.py
|
||||||
|
- vllm/model_executor/layers/quantization/mxfp4.py
|
||||||
|
- vllm/v1/attention/backends/flashinfer.py
|
||||||
|
commands:
|
||||||
|
- pytest -s -v tests/quantization/test_blackwell_moe.py
|
||||||
|
|
||||||
|
- label: Blackwell LM Eval Small Models
|
||||||
|
timeout_in_minutes: 120
|
||||||
|
gpu: b200
|
||||||
|
optional: true # run on nightlies
|
||||||
|
source_file_dependencies:
|
||||||
|
- csrc/
|
||||||
|
- vllm/model_executor/layers/quantization
|
||||||
|
commands:
|
||||||
|
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-blackwell.txt --tp-size=1
|
||||||
|
|
||||||
##### 1 GPU test #####
|
##### 1 GPU test #####
|
||||||
##### multi gpus test #####
|
##### multi gpus test #####
|
||||||
@ -889,14 +937,13 @@ steps:
|
|||||||
- tests/compile/test_wrapper.py
|
- tests/compile/test_wrapper.py
|
||||||
- tests/distributed/
|
- tests/distributed/
|
||||||
- tests/entrypoints/llm/test_collective_rpc.py
|
- tests/entrypoints/llm/test_collective_rpc.py
|
||||||
- tests/v1/test_async_llm_dp.py
|
- tests/v1/distributed
|
||||||
- tests/v1/test_external_lb_dp.py
|
|
||||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||||
- tests/v1/shutdown
|
- tests/v1/shutdown
|
||||||
- tests/v1/worker/test_worker_memory_snapshot.py
|
- tests/v1/worker/test_worker_memory_snapshot.py
|
||||||
commands:
|
commands:
|
||||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_async_llm_dp.py
|
||||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/distributed/test_external_lb_dp.py
|
||||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||||
- pytest -v -s ./compile/test_basic_correctness.py
|
- pytest -v -s ./compile/test_basic_correctness.py
|
||||||
@ -1047,6 +1094,8 @@ steps:
|
|||||||
working_dir: "/vllm-workspace/"
|
working_dir: "/vllm-workspace/"
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
commands:
|
commands:
|
||||||
|
- pytest -v -s tests/compile/test_async_tp.py
|
||||||
|
- pytest -v -s tests/compile/test_sequence_parallelism.py
|
||||||
- pytest -v -s tests/distributed/test_context_parallel.py
|
- pytest -v -s tests/distributed/test_context_parallel.py
|
||||||
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048
|
||||||
|
|
||||||
|
.github/CODEOWNERS (vendored, 11 lines changed)
@ -12,8 +12,6 @@
|
|||||||
/vllm/model_executor/layers/mamba @tdoublep
|
/vllm/model_executor/layers/mamba @tdoublep
|
||||||
/vllm/model_executor/model_loader @22quinn
|
/vllm/model_executor/model_loader @22quinn
|
||||||
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
/vllm/multimodal @DarkLight1337 @ywang96 @NickLucche
|
||||||
/vllm/v1/attention @LucasWilkinson
|
|
||||||
/vllm/v1/sample @22quinn @houseroad
|
|
||||||
/vllm/vllm_flash_attn @LucasWilkinson
|
/vllm/vllm_flash_attn @LucasWilkinson
|
||||||
/vllm/lora @jeejeelee
|
/vllm/lora @jeejeelee
|
||||||
/vllm/reasoning @aarnphm @chaunceyjiang
|
/vllm/reasoning @aarnphm @chaunceyjiang
|
||||||
@ -25,14 +23,17 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||||
# so spam a lot of people
|
# so spam a lot of people
|
||||||
/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg
|
/vllm/config @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg
|
||||||
|
/vllm/config/cache.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor @yewentao256 @ProExpertProg @heheda12345
|
||||||
|
|
||||||
# vLLM V1
|
# vLLM V1
|
||||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
/vllm/v1/attention @LucasWilkinson
|
||||||
/vllm/v1/spec_decode @benchislett @luccafong
|
|
||||||
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
/vllm/v1/attention/backends/flashinfer.py @mgoin
|
||||||
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
/vllm/v1/attention/backends/triton_attn.py @tdoublep
|
||||||
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
|
||||||
|
/vllm/v1/sample @22quinn @houseroad @njhill
|
||||||
|
/vllm/v1/spec_decode @benchislett @luccafong
|
||||||
|
/vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett
|
||||||
/vllm/v1/kv_cache_interface.py @heheda12345
|
/vllm/v1/kv_cache_interface.py @heheda12345
|
||||||
/vllm/v1/offloading @ApostaC
|
/vllm/v1/offloading @ApostaC
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
|||||||
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
/tests/weight_loading @mgoin @youkaichao @yewentao256
|
||||||
/tests/lora @jeejeelee
|
/tests/lora @jeejeelee
|
||||||
/tests/models/language/generation/test_hybrid.py @tdoublep
|
/tests/models/language/generation/test_hybrid.py @tdoublep
|
||||||
/tests/v1/kv_connector/nixl_integration @NickLucche
|
/tests/v1/kv_connector/nixl_integration @NickLucche
|
||||||
/tests/v1/kv_connector @ApostaC
|
/tests/v1/kv_connector @ApostaC
|
||||||
/tests/v1/offloading @ApostaC
|
/tests/v1/offloading @ApostaC
|
||||||
|
|
||||||
|
.github/mergify.yml (vendored, 35 lines changed)
@ -2,6 +2,7 @@ pull_request_rules:
|
|||||||
- name: label-documentation
|
- name: label-documentation
|
||||||
description: Automatically apply documentation label
|
description: Automatically apply documentation label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^[^/]+\.md$
|
- files~=^[^/]+\.md$
|
||||||
- files~=^docs/
|
- files~=^docs/
|
||||||
@ -10,10 +11,13 @@ pull_request_rules:
|
|||||||
label:
|
label:
|
||||||
add:
|
add:
|
||||||
- documentation
|
- documentation
|
||||||
|
comment:
|
||||||
|
message: "Documentation preview: https://vllm--{{number}}.org.readthedocs.build/en/{{number}}/"
|
||||||
|
|
||||||
- name: label-ci-build
|
- name: label-ci-build
|
||||||
description: Automatically apply ci/build label
|
description: Automatically apply ci/build label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^\.github/
|
- files~=^\.github/
|
||||||
- files~=\.buildkite/
|
- files~=\.buildkite/
|
||||||
@ -30,6 +34,7 @@ pull_request_rules:
|
|||||||
- name: label-deepseek
|
- name: label-deepseek
|
||||||
description: Automatically apply deepseek label
|
description: Automatically apply deepseek label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^examples/.*deepseek.*\.py
|
- files~=^examples/.*deepseek.*\.py
|
||||||
- files~=^tests/.*deepseek.*\.py
|
- files~=^tests/.*deepseek.*\.py
|
||||||
@ -46,6 +51,7 @@ pull_request_rules:
|
|||||||
- name: label-frontend
|
- name: label-frontend
|
||||||
description: Automatically apply frontend label
|
description: Automatically apply frontend label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- files~=^vllm/entrypoints/
|
- files~=^vllm/entrypoints/
|
||||||
actions:
|
actions:
|
||||||
label:
|
label:
|
||||||
@ -55,6 +61,7 @@ pull_request_rules:
|
|||||||
- name: label-llama
|
- name: label-llama
|
||||||
description: Automatically apply llama label
|
description: Automatically apply llama label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^examples/.*llama.*\.py
|
- files~=^examples/.*llama.*\.py
|
||||||
- files~=^tests/.*llama.*\.py
|
- files~=^tests/.*llama.*\.py
|
||||||
@ -70,6 +77,7 @@ pull_request_rules:
|
|||||||
- name: label-multi-modality
|
- name: label-multi-modality
|
||||||
description: Automatically apply multi-modality label
|
description: Automatically apply multi-modality label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^vllm/multimodal/
|
- files~=^vllm/multimodal/
|
||||||
- files~=^tests/multimodal/
|
- files~=^tests/multimodal/
|
||||||
@ -83,6 +91,7 @@ pull_request_rules:
|
|||||||
- name: label-new-model
|
- name: label-new-model
|
||||||
description: Automatically apply new-model label
|
description: Automatically apply new-model label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- and:
|
- and:
|
||||||
- files~=^vllm/model_executor/models/
|
- files~=^vllm/model_executor/models/
|
||||||
- files=vllm/model_executor/models/registry.py
|
- files=vllm/model_executor/models/registry.py
|
||||||
@ -94,6 +103,7 @@ pull_request_rules:
|
|||||||
- name: label-performance
|
- name: label-performance
|
||||||
description: Automatically apply performance label
|
description: Automatically apply performance label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^benchmarks/
|
- files~=^benchmarks/
|
||||||
- files~=^vllm/benchmarks/
|
- files~=^vllm/benchmarks/
|
||||||
@ -107,6 +117,7 @@ pull_request_rules:
|
|||||||
- name: label-qwen
|
- name: label-qwen
|
||||||
description: Automatically apply qwen label
|
description: Automatically apply qwen label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^examples/.*qwen.*\.py
|
- files~=^examples/.*qwen.*\.py
|
||||||
- files~=^tests/.*qwen.*\.py
|
- files~=^tests/.*qwen.*\.py
|
||||||
@ -121,6 +132,7 @@ pull_request_rules:
|
|||||||
- name: label-gpt-oss
|
- name: label-gpt-oss
|
||||||
description: Automatically apply gpt-oss label
|
description: Automatically apply gpt-oss label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^examples/.*gpt[-_]?oss.*\.py
|
- files~=^examples/.*gpt[-_]?oss.*\.py
|
||||||
- files~=^tests/.*gpt[-_]?oss.*\.py
|
- files~=^tests/.*gpt[-_]?oss.*\.py
|
||||||
@ -142,6 +154,7 @@ pull_request_rules:
|
|||||||
- name: label-rocm
|
- name: label-rocm
|
||||||
description: Automatically apply rocm label
|
description: Automatically apply rocm label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^csrc/rocm/
|
- files~=^csrc/rocm/
|
||||||
- files~=^docker/Dockerfile.rocm
|
- files~=^docker/Dockerfile.rocm
|
||||||
@ -162,6 +175,7 @@ pull_request_rules:
|
|||||||
- name: label-structured-output
|
- name: label-structured-output
|
||||||
description: Automatically apply structured-output label
|
description: Automatically apply structured-output label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^benchmarks/structured_schemas/
|
- files~=^benchmarks/structured_schemas/
|
||||||
- files=benchmarks/benchmark_serving_structured_output.py
|
- files=benchmarks/benchmark_serving_structured_output.py
|
||||||
@ -181,6 +195,7 @@ pull_request_rules:
|
|||||||
- name: label-speculative-decoding
|
- name: label-speculative-decoding
|
||||||
description: Automatically apply speculative-decoding label
|
description: Automatically apply speculative-decoding label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^vllm/v1/spec_decode/
|
- files~=^vllm/v1/spec_decode/
|
||||||
- files~=^tests/v1/spec_decode/
|
- files~=^tests/v1/spec_decode/
|
||||||
@ -196,6 +211,7 @@ pull_request_rules:
|
|||||||
- name: label-v1
|
- name: label-v1
|
||||||
description: Automatically apply v1 label
|
description: Automatically apply v1 label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^vllm/v1/
|
- files~=^vllm/v1/
|
||||||
- files~=^tests/v1/
|
- files~=^tests/v1/
|
||||||
@ -208,6 +224,7 @@ pull_request_rules:
|
|||||||
description: Automatically apply tpu label
|
description: Automatically apply tpu label
|
||||||
# Keep this list in sync with `label-tpu-remove` conditions
|
# Keep this list in sync with `label-tpu-remove` conditions
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=tpu.py
|
- files~=tpu.py
|
||||||
- files~=_tpu
|
- files~=_tpu
|
||||||
@ -223,6 +240,7 @@ pull_request_rules:
|
|||||||
description: Automatically remove tpu label
|
description: Automatically remove tpu label
|
||||||
# Keep this list in sync with `label-tpu` conditions
|
# Keep this list in sync with `label-tpu` conditions
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- and:
|
- and:
|
||||||
- -files~=tpu.py
|
- -files~=tpu.py
|
||||||
- -files~=_tpu
|
- -files~=_tpu
|
||||||
@ -237,9 +255,9 @@ pull_request_rules:
|
|||||||
- name: label-tool-calling
|
- name: label-tool-calling
|
||||||
description: Automatically add tool-calling label
|
description: Automatically add tool-calling label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^tests/tool_use/
|
- files~=^tests/tool_use/
|
||||||
- files~=^tests/mistral_tool_use/
|
|
||||||
- files~=^tests/entrypoints/openai/tool_parsers/
|
- files~=^tests/entrypoints/openai/tool_parsers/
|
||||||
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
|
- files=tests/entrypoints/openai/test_chat_with_tool_reasoning.py
|
||||||
- files~=^vllm/entrypoints/openai/tool_parsers/
|
- files~=^vllm/entrypoints/openai/tool_parsers/
|
||||||
@ -256,8 +274,9 @@ pull_request_rules:
|
|||||||
|
|
||||||
- name: ping author on conflicts and add 'needs-rebase' label
|
- name: ping author on conflicts and add 'needs-rebase' label
|
||||||
conditions:
|
conditions:
|
||||||
- conflict
|
- label != stale
|
||||||
- -closed
|
- conflict
|
||||||
|
- -closed
|
||||||
actions:
|
actions:
|
||||||
label:
|
label:
|
||||||
add:
|
add:
|
||||||
@ -271,10 +290,12 @@ pull_request_rules:
|
|||||||
|
|
||||||
- name: assign reviewer for tensorizer changes
|
- name: assign reviewer for tensorizer changes
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
|
- or:
|
||||||
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
- files~=^vllm/model_executor/model_loader/tensorizer.py
|
||||||
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
- files~=^vllm/model_executor/model_loader/tensorizer_loader.py
|
||||||
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
- files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py
|
||||||
- files~=^tests/tensorizer_loader/
|
- files~=^tests/model_executor/model_loader/tensorizer_loader/
|
||||||
actions:
|
actions:
|
||||||
assign:
|
assign:
|
||||||
users:
|
users:
|
||||||
@ -282,6 +303,7 @@ pull_request_rules:
|
|||||||
|
|
||||||
- name: assign reviewer for modelopt changes
|
- name: assign reviewer for modelopt changes
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^vllm/model_executor/layers/quantization/modelopt\.py$
|
- files~=^vllm/model_executor/layers/quantization/modelopt\.py$
|
||||||
- files~=^vllm/model_executor/layers/quantization/__init__\.py$
|
- files~=^vllm/model_executor/layers/quantization/__init__\.py$
|
||||||
@ -296,8 +318,8 @@ pull_request_rules:
|
|||||||
|
|
||||||
- name: remove 'needs-rebase' label when conflict is resolved
|
- name: remove 'needs-rebase' label when conflict is resolved
|
||||||
conditions:
|
conditions:
|
||||||
- -conflict
|
- -conflict
|
||||||
- -closed
|
- -closed
|
||||||
actions:
|
actions:
|
||||||
label:
|
label:
|
||||||
remove:
|
remove:
|
||||||
@ -306,6 +328,7 @@ pull_request_rules:
|
|||||||
- name: label-kv-connector
|
- name: label-kv-connector
|
||||||
description: Automatically apply kv-connector label
|
description: Automatically apply kv-connector label
|
||||||
conditions:
|
conditions:
|
||||||
|
- label != stale
|
||||||
- or:
|
- or:
|
||||||
- files~=^examples/online_serving/disaggregated[^/]*/.*
|
- files~=^examples/online_serving/disaggregated[^/]*/.*
|
||||||
- files~=^examples/offline_inference/disaggregated[^/]*/.*
|
- files~=^examples/offline_inference/disaggregated[^/]*/.*
|
||||||
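The change repeated across these rules adds a `- label != stale` condition, so mergify stops re-applying labels and re-pinging reviewers on pull requests that have already been marked stale; the remaining edits add a documentation-preview comment and update the tool-calling and tensorizer file patterns. As a rough illustration of how the `files~=` regexes classify a PR (the RULES table, helper, and example paths below are made up for this note; mergify evaluates the real conditions itself):

import re

RULES = {
    "v1": [r"^vllm/v1/", r"^tests/v1/"],
    "frontend": [r"^vllm/entrypoints/"],
    "documentation": [r"^[^/]+\.md$", r"^docs/"],
}

def labels_for(changed_files: list[str], pr_labels: set[str]) -> set[str]:
    # Mirrors the new "- label != stale" guard: stale PRs get no new labels.
    if "stale" in pr_labels:
        return set()
    return {
        label
        for label, patterns in RULES.items()
        if any(re.search(p, f) for p in patterns for f in changed_files)
    }

print(labels_for(["vllm/v1/core/sched.py", "docs/index.md"], set()))  # {'v1', 'documentation'}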
.github/workflows/stale.yml

@@ -13,7 +13,7 @@ jobs:
       actions: write
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v10.0.0
+      - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v10.1.0
         with:
           # Increasing this value ensures that changes to this workflow
           # propagate to all issues and PRs in days rather than months
@@ -6,30 +6,18 @@ default_stages:
   - manual # Run in CI
 exclude: 'vllm/third_party/.*'
 repos:
-- repo: https://github.com/google/yapf
-  rev: v0.43.0
-  hooks:
-  - id: yapf
-    args: [--in-place, --verbose]
-    # Keep the same list from yapfignore here to avoid yapf failing without any inputs
-    exclude: '(.buildkite|benchmarks|build|examples)/.*'
 - repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.11.7
+  rev: v0.14.0
   hooks:
-  - id: ruff
+  - id: ruff-check
     args: [--output-format, github, --fix]
   - id: ruff-format
-    files: ^(.buildkite|benchmarks|examples)/.*
 - repo: https://github.com/crate-ci/typos
-  rev: v1.35.5
+  rev: v1.38.1
   hooks:
   - id: typos
-- repo: https://github.com/PyCQA/isort
-  rev: 6.0.1
-  hooks:
-  - id: isort
 - repo: https://github.com/pre-commit/mirrors-clang-format
-  rev: v20.1.3
+  rev: v21.1.2
   hooks:
   - id: clang-format
     exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*'
@@ -46,7 +34,7 @@ repos:
   hooks:
   - id: actionlint
 - repo: https://github.com/astral-sh/uv-pre-commit
-  rev: 0.6.17
+  rev: 0.9.1
   hooks:
   - id: pip-compile
     args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28]
@@ -67,11 +55,6 @@ repos:
     types_or: [python, pyi]
     require_serial: true
     additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic]
-  - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
-    name: Run mypy for Python 3.9
-    entry: python tools/pre_commit/mypy.py 1 "3.9"
-    <<: *mypy_common
-    stages: [manual] # Only run in CI
   - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
     name: Run mypy for Python 3.10
     entry: python tools/pre_commit/mypy.py 1 "3.10"
@@ -87,6 +70,11 @@ repos:
     entry: python tools/pre_commit/mypy.py 1 "3.12"
     <<: *mypy_common
     stages: [manual] # Only run in CI
+  - id: mypy-3.13 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
+    name: Run mypy for Python 3.13
+    entry: python tools/pre_commit/mypy.py 1 "3.13"
+    <<: *mypy_common
+    stages: [manual] # Only run in CI
   - id: shellcheck
     name: Lint shell scripts
     entry: tools/shellcheck.sh
CMakeLists.txt

@@ -34,10 +34,10 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
 # Supported python versions. These versions will be searched in order, the
 # first match will be selected. These should be kept in sync with setup.py.
 #
-set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13")
+set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")

 # Supported AMD GPU architectures.
-set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
+set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151")

 #
 # Supported/expected torch versions for CUDA/ROCm.
@@ -86,6 +86,9 @@ find_package(Torch REQUIRED)
 # Supported NVIDIA architectures.
 # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined
 if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
+   CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+  set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0")
+elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND
    CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
   set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
 else()
@@ -175,6 +178,15 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
 endif()

+#
+# Set compression mode for CUDA >=13.x.
+#
+if(VLLM_GPU_LANG STREQUAL "CUDA" AND
+   DEFINED CMAKE_CUDA_COMPILER_VERSION AND
+   CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+  list(APPEND VLLM_GPU_FLAGS "--compress-mode=size")
+endif()
+
 #
 # Set CUDA include flags for CXX compiler.
 #
@@ -257,8 +269,8 @@ set(VLLM_EXT_SRC
   "csrc/sampler.cu"
   "csrc/cuda_view.cu"
   "csrc/quantization/gptq/q_gemm.cu"
-  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
-  "csrc/quantization/fp8/common.cu"
+  "csrc/quantization/w8a8/int8/scaled_quant.cu"
+  "csrc/quantization/w8a8/fp8/common.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/quantization/activation_kernels.cu"
@@ -270,7 +282,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

   # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
-  set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v4.2.1" CACHE STRING "CUTLASS revision to use")

   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -302,13 +314,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/permute_cols.cu"
-    "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
-    "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp"
-    "csrc/quantization/fp8/per_token_group_quant.cu")
+    "csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
+    "csrc/quantization/w8a8/int8/per_token_group_quant.cu")

   set_gencode_flags_for_srcs(
     SRCS "${VLLM_EXT_SRC}"
@@ -412,11 +424,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_int8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_azp_sm90_int8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8.cu")
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -440,12 +452,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

   # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
   # CUDA 12.8 or later
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -470,12 +486,16 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

   # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
   # require CUDA 12.8 or later
-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
     set(SRCS
-      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
-      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu"
+      "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -506,7 +526,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # subtract out the archs that are already built for 3x
   list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
   if (SCALED_MM_2X_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
@@ -550,7 +570,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

   # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
   # CUDA 12.8 or later
-  cuda_archs_loose_intersection(FP4_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
@@ -569,7 +593,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()

   # FP4 Archs and flags
-  cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(FP4_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(FP4_ARCHS "10.0a;10.1a;12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
@@ -591,7 +619,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   endif()

   # CUTLASS MLA Archs and flags
-  cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
     set(SRCS
       "csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
@@ -617,7 +649,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   # if it's possible to compile MoE kernels that use its output.
   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -635,9 +667,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -656,9 +692,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

   # moe_data.cu is used by all CUTLASS MoE kernels.
-  cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
@@ -675,9 +715,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()

-  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
+  else()
+    cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}")
+  endif()
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
-    set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
+    set(SRCS "csrc/quantization/w8a8/cutlass/moe/blockwise_scaled_group_mm_sm100.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
       CUDA_ARCHS "${SCALED_MM_ARCHS}")
@@ -963,6 +1007,7 @@ endif()
 # For CUDA we also build and ship some external projects.
 if (VLLM_GPU_LANG STREQUAL "CUDA")
   include(cmake/external_projects/flashmla.cmake)
+  include(cmake/external_projects/qutlass.cmake)

   # vllm-flash-attn should be last as it overwrites some CMake functions
   include(cmake/external_projects/vllm_flash_attn.cmake)
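Most of the CMake hunks above follow one pattern: with a CUDA 13.0+ toolkit the kernel source lists are gated on the new family-style arch names (10.0f, 11.0f, 12.0f, ...), while older toolkits keep the per-chip "a" variants. A Python paraphrase of that selection rule, for illustration only; the real logic is the CMake shown above, and cuda_archs_loose_intersection matches more loosely than the plain set intersection used here:

def scaled_mm_archs(cuda_version: tuple[int, int], available: set[str]) -> set[str]:
    # CUDA 13+ toolkits get the family-specific "f" arch names,
    # older toolkits keep the arch-specific "a" variants.
    wanted = (
        {"10.0f", "11.0f", "12.0f"}
        if cuda_version >= (13, 0)
        else {"10.0a", "10.1a", "10.3a", "12.0a", "12.1a"}
    )
    return wanted & available  # simplified stand-in for cuda_archs_loose_intersection

print(scaled_mm_archs((13, 0), {"10.0f", "12.0f"}))  # both family archs are kept
print(scaled_mm_archs((12, 8), {"10.0a", "9.0a"}))   # only the matching "a" arch remains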
@@ -21,6 +21,7 @@ Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundatio

 *Latest News* 🔥

+- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
 - [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
 - [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing).
 - [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
@@ -148,6 +149,7 @@ Compute Resources:
 - Trainy
 - UC Berkeley
 - UC San Diego
+- Volcengine

 Slack Sponsor: Anyscale

@@ -74,7 +74,7 @@ start_server() {
   local vllm_log=$4
   local profile_dir=$5

-  pkill -if vllm
+  pkill -if "vllm serve" || true

   # Define the common arguments as a bash array.
   # Each argument and its value are separate elements.
@@ -96,11 +96,11 @@ start_server() {
   # This correctly passes each element as a separate argument.
   if [[ -n "$profile_dir" ]]; then
     # Start server with profiling enabled
-    VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
+    VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir \
       vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
   else
     # Start server without profiling
-    VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 \
+    VLLM_SERVER_DEV_MODE=1 \
       vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
   fi
   local server_pid=$!
@@ -139,7 +139,7 @@ run_benchmark() {
   echo "vllm_log: $vllm_log"
   echo
   rm -f $vllm_log
-  pkill -if vllm
+  pkill -if "vllm serve" || true

   echo "starting server..."
   # Call start_server without a profile_dir to avoid profiling overhead
@@ -232,7 +232,7 @@ run_benchmark() {

   echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"

-  pkill -if vllm
+  pkill -if "vllm serve" || true
   sleep 10
   echo "===================="
   return 0
@@ -308,6 +308,6 @@ if (( $(echo "$best_throughput > 0" | bc -l) )); then
 else
   echo "No configuration met the latency requirements. Skipping final profiling run."
 fi
-pkill -if vllm
+pkill -if "vllm serve" || true
 echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
 echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
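The repeated change from `pkill -if vllm` to `pkill -if "vllm serve" || true` narrows the match to the launched server command line, so the cleanup no longer kills unrelated processes whose command line merely contains "vllm", and it tolerates the no-match case: pkill exits non-zero when nothing matches, which would otherwise abort the script. A rough Python equivalent of the new cleanup step, illustrative only:

import subprocess

def stop_vllm_server() -> None:
    # -i: case-insensitive match, -f: match against the full command line.
    # check=False mirrors the trailing "|| true": a non-zero exit (no match) is not an error.
    subprocess.run(["pkill", "-if", "vllm serve"], check=False)

stop_vllm_server()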
@@ -2,9 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc

+from benchmark_utils import TimeCollector
 from tabulate import tabulate

-from benchmark_utils import TimeCollector
 from vllm.utils import FlexibleArgumentParser
 from vllm.v1.core.block_pool import BlockPool

@@ -5,9 +5,9 @@ import time
 from unittest import mock

 import numpy as np
+from benchmark_utils import TimeCollector
 from tabulate import tabulate

-from benchmark_utils import TimeCollector
 from vllm.config import (
     CacheConfig,
     DeviceConfig,
@@ -164,7 +164,7 @@ def invoke_main() -> None:
     )
     parser.add_argument(
         "--batched", action="store_true", help="consider time to prepare batch"
-    )  # noqa: E501
+    )
     parser.add_argument(
         "--num-iteration",
         type=int,
@@ -37,14 +37,13 @@ from typing import Optional
 import datasets
 import numpy as np
 import pandas as pd
-from tqdm.asyncio import tqdm
-from transformers import PreTrainedTokenizerBase

 from backend_request_func import (
     ASYNC_REQUEST_FUNCS,
     RequestFuncInput,
     RequestFuncOutput,
 )
+from tqdm.asyncio import tqdm
+from transformers import PreTrainedTokenizerBase

 try:
     from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -910,13 +909,13 @@ def create_argument_parser():
     parser.add_argument(
         "--tokenizer",
         type=str,
-        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
+        help="Name or path of the tokenizer, if not using the default tokenizer.",
     )
     parser.add_argument(
         "--tokenizer-mode",
         type=str,
         default="auto",
-        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
+        help="Name or path of the tokenizer, if not using the default tokenizer.",
     )
     parser.add_argument(
         "--num-prompts",
@@ -17,7 +17,7 @@ from weight_shapes import WEIGHT_SHAPES

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    w8a8_block_fp8_matmul,
+    w8a8_triton_block_scaled_mm,
 )
 from vllm.utils import FlexibleArgumentParser, cdiv

@@ -158,7 +158,7 @@ def bench_fp8(
         "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
             a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
         ),
-        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
+        "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
             a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
         ),
         "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
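The only functional change in this benchmark is the rename of the blockwise Triton helper; the call keeps the same arguments. A sketch of how the renamed entry reads after the change, mirroring the call site above (the wrapper function and its operands are taken from the benchmark, the wrapper itself is hypothetical):

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    w8a8_triton_block_scaled_mm,  # previously imported as w8a8_block_fp8_matmul
)

def blockwise_fp8_mm(a_cont, b, block_scale_a, block_scale_b):
    # (128, 128) is the quantization block size used by the benchmark.
    return w8a8_triton_block_scaled_mm(
        a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
    )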
@@ -55,9 +55,7 @@ benchmark() {
   output_len=$2


-  CUDA_VISIBLE_DEVICES=0 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
     --port 8100 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
@@ -65,9 +63,7 @@ benchmark() {
     '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &


-  CUDA_VISIBLE_DEVICES=1 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
     --port 8200 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
@@ -38,16 +38,12 @@ wait_for_server() {
 launch_chunked_prefill() {
   model="meta-llama/Meta-Llama-3.1-8B-Instruct"
   # disagg prefill
-  CUDA_VISIBLE_DEVICES=0 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
     --port 8100 \
     --max-model-len 10000 \
     --enable-chunked-prefill \
     --gpu-memory-utilization 0.6 &
-  CUDA_VISIBLE_DEVICES=1 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
     --port 8200 \
     --max-model-len 10000 \
     --enable-chunked-prefill \
@@ -62,18 +58,14 @@ launch_chunked_prefill() {
 launch_disagg_prefill() {
   model="meta-llama/Meta-Llama-3.1-8B-Instruct"
   # disagg prefill
-  CUDA_VISIBLE_DEVICES=0 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=0 vllm serve $model \
     --port 8100 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
     --kv-transfer-config \
     '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &

-  CUDA_VISIBLE_DEVICES=1 python3 \
-    -m vllm.entrypoints.openai.api_server \
-    --model $model \
+  CUDA_VISIBLE_DEVICES=1 vllm serve $model \
     --port 8200 \
     --max-model-len 10000 \
     --gpu-memory-utilization 0.6 \
|
191
benchmarks/kernels/bench_mxfp4_qutlass.py
Normal file
191
benchmarks/kernels/bench_mxfp4_qutlass.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||||
|
# All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
|
||||||
|
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"torch-bf16": dict(enabled=True),
|
||||||
|
"mxfp4": dict(no_a_quant=False, enabled=True),
|
||||||
|
"mxfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||||
|
return (
|
||||||
|
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||||
|
* group_size**-0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_weight_mxfp4(
|
||||||
|
b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
|
||||||
|
):
|
||||||
|
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
|
||||||
|
b, forward_hadamard_matrix, method="abs_max"
|
||||||
|
)
|
||||||
|
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
|
||||||
|
return weight_hf_e2m1, weight_hf_scale_block
|
||||||
|
|
||||||
|
|
||||||
|
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
|
||||||
|
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
|
||||||
|
b, forward_hadamard_matrix, device
|
||||||
|
)
|
||||||
|
alpha = torch.tensor([1.0], device="cuda")
|
||||||
|
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
# Pre-quantize activation
|
||||||
|
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
|
||||||
|
a, forward_hadamard_matrix, method="abs_max"
|
||||||
|
)
|
||||||
|
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return matmul_mxf4_bf16_tn(
|
||||||
|
input_hf_e2m1,
|
||||||
|
weight_hf_e2m1,
|
||||||
|
input_hf_scale_block,
|
||||||
|
weight_hf_scale_block,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
# Quantize activation on-the-fly
|
||||||
|
def run():
|
||||||
|
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
|
||||||
|
a, forward_hadamard_matrix, method="abs_max"
|
||||||
|
)
|
||||||
|
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
|
||||||
|
return matmul_mxf4_bf16_tn(
|
||||||
|
input_hf_e2m1,
|
||||||
|
weight_hf_e2m1,
|
||||||
|
input_hf_scale_block,
|
||||||
|
weight_hf_scale_block,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[
|
||||||
|
1,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
32,
|
||||||
|
64,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
2048,
|
||||||
|
4096,
|
||||||
|
8192,
|
||||||
|
16384,
|
||||||
|
24576,
|
||||||
|
32768,
|
||||||
|
],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=_enabled,
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs MXFP4 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K, had_size):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||||
|
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
run_quant = build_mxfp4_runner(
|
||||||
|
cfg, a, b, forward_hadamard_matrix, dtype, device
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_quant(), rep=200, quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
out = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
out.append(KN)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.3-70B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
for had_size in [32, 64, 128]:
|
||||||
|
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
show_plots=True,
|
||||||
|
save_path=f"bench_mxfp4_res_n{N}_k{K}",
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
had_size=had_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
benchmarks/kernels/bench_nvfp4_qutlass.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import copy
import itertools

import torch
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops  # use existing nvfp4 gemm in vllm
from vllm._custom_ops import fusedQuantizeNv
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
from vllm.triton_utils import triton

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    return (
        deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
        * group_size**-0.5
    )


def _quant_weight_nvfp4(
    b: torch.Tensor,
    forward_hadamard_matrix: torch.Tensor,
    global_scale: torch.Tensor,
    device: str,
    M: int,
    N: int,
    K: int,
):
    weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
        b, forward_hadamard_matrix, global_scale
    )
    weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
        -1, K // 16
    )
    return weight_hf_e2m1, weight_hf_scale_block


def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
    alpha = torch.tensor([1.0], device="cuda")
    global_scale = torch.tensor([1.0], device="cuda")
    weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
        b, forward_hadamard_matrix, global_scale, device, M, N, K
    )

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )

        def run():
            return ops.cutlass_scaled_fp4_mm(
                input_hf_e2m1,
                weight_hf_e2m1,
                input_hf_scale_block,
                weight_hf_scale_block,
                alpha,
                torch.bfloat16,
            )

        return run

    # Quantize activation on-the-fly
    def run():
        input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
            a, forward_hadamard_matrix, global_scale
        )
        input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
            -1, K // 16
        )
        return ops.cutlass_scaled_fp4_mm(
            input_hf_e2m1,
            weight_hf_e2m1,
            input_hf_scale_block,
            weight_hf_scale_block,
            alpha,
            torch.bfloat16,
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            4096,
            8192,
            16384,
            24576,
            32768,
        ],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, had_size):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)
    forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(
            cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
        )
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), rep=200, quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.3-70B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        for had_size in [16, 32, 64, 128]:
            print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
            benchmark.run(
                print_data=True,
                show_plots=True,
                save_path=f"bench_nvfp4_res_n{N}_k{K}",
                N=N,
                K=K,
                had_size=had_size,
            )

    print("Benchmark finished!")
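For orientation, a minimal invocation sketch for the new NVFP4 benchmark, using only the flags its parser defines above (--models is limited to keys of WEIGHT_SHAPES, --tp-sizes defaults to 1):

    python benchmarks/kernels/bench_nvfp4_qutlass.py --models meta-llama/Llama-3.3-70B-Instruct --tp-sizes 1

Each (N, K) weight shape then produces one Triton perf report per Hadamard size under bench_nvfp4_res_n{N}_k{K}, with a curve per enabled provider (torch-bf16, nvfp4, nvfp4-noquant).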
@@ -579,18 +579,22 @@ def main(args: argparse.Namespace):
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
         intermediate_size = config.ffn_config.ffn_hidden_size
+        hidden_size = config.hidden_size
     elif config.architectures[0] == "JambaForCausalLM":
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
+        hidden_size = config.hidden_size
     elif config.architectures[0] in (
-        "DeepseekV3ForCausalLM",
         "DeepseekV2ForCausalLM",
+        "DeepseekV3ForCausalLM",
+        "DeepseekV32ForCausalLM",
         "Glm4MoeForCausalLM",
     ):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
+        hidden_size = config.hidden_size
     elif config.architectures[0] in (
         "Qwen2MoeForCausalLM",
         "Qwen3MoeForCausalLM",
@@ -599,10 +603,18 @@ def main(args: argparse.Namespace):
         E = config.num_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
+        hidden_size = config.hidden_size
+    elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
+        text_config = config.get_text_config()
+        E = text_config.num_experts
+        topk = text_config.num_experts_per_tok
+        intermediate_size = text_config.moe_intermediate_size
+        hidden_size = text_config.hidden_size
     elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
         E = config.num_experts
         topk = config.moe_topk[0]
         intermediate_size = config.moe_intermediate_size[0]
+        hidden_size = config.hidden_size
     else:
         # Support for llama4
         config = config.get_text_config()
@@ -610,6 +622,7 @@ def main(args: argparse.Namespace):
         E = config.num_local_experts
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
+        hidden_size = config.hidden_size
     enable_ep = bool(args.enable_expert_parallel)
     if enable_ep:
         ensure_divisibility(E, args.tp_size, "Number of experts")
@@ -618,7 +631,6 @@ def main(args: argparse.Namespace):
     else:
         ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
-    hidden_size = config.hidden_size
     dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
benchmarks/kernels/benchmark_reshape_and_cache.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    function_under_test = lambda: ops.reshape_and_cache(
        key,  # noqa: F821
        value,  # noqa: F821
        key_cache,  # noqa: F821
        value_cache,  # noqa: F821
        slot_mapping,  # noqa: F821
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for exp in range(1, 17):
        n_tok = 2**exp
        lat = run_benchmark(
            num_tokens=n_tok,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        rows.append([n_tok, lat * 1e6])  # convert to microseconds

    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
    print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 128)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=200)

    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )

    args = parser.parse_args()

    main(args)
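A minimal usage sketch for the new reshape_and_cache benchmark, restricted to the arguments defined in its parser (the values shown are the parser defaults):

    python benchmarks/kernels/benchmark_reshape_and_cache.py --num-heads 128 --head-size 128 --block-size 16 --num-blocks 16384 --dtype bfloat16 --kv-cache-dtype auto --iters 200 --mode cudagraph

The script sweeps num_tokens = 2**1 through 2**16 internally and prints one tabulated latency (µs) row per token count.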
@@ -1,5 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""
+Comprehensive 3-way SiLU Benchmark Suite
+
+This benchmark compares three SiLU implementations:
+1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
+2. Triton Kernel - Triton-based implementation
+
+The suite generates detailed performance comparisons including:
+- Memory bandwidth utilization
+- Speedup ratios (baseline vs optimized implementations)
+- Performance across different expert configurations and token distributions
+"""
+
 from collections.abc import Callable

 import matplotlib.pyplot as plt
@@ -7,7 +21,7 @@ import numpy as np
 import torch

 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-    silu_mul_fp8_quant_deep_gemm_cuda,
+    persistent_masked_m_silu_mul_quant,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
@@ -94,6 +108,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(
     num_parallel_tokens,
     group_size: int = 128,
     eps: float = 1e-10,
+    expert_offsets: torch.Tensor = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
@@ -174,7 +189,7 @@ def silu_mul_fp8_quant_deep_gemm_triton(


 # Parse generation strategies
-strategies = ["uniform", "max_t", "first_t"]
+strategies = ["random_imbalanced", "uniform", "max_t"]


 def benchmark(
@@ -195,15 +210,27 @@ def benchmark(
     current_platform.seed_everything(42 + seed_offset)
     y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()

-    if gen_strategy == "uniform":
-        r = torch.rand(size=(E,), device="cuda")
+    if gen_strategy == "random_imbalanced":
+
+        def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
+            mean = total_tokens // n_e
+            min_max = mean // ratio
+            e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
+            e[0] = min_max
+            r = torch.rand(size=(E - 1,))
+            r /= r.sum()
+            r *= total_tokens - min_max
+            r = r.round().long()
+            e[1:] = r.to(device=device)
+            return e
+
+        tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
+    elif gen_strategy == "uniform":
+        r = torch.rand(size=(E,))
         r /= r.sum()
         r *= total_tokens
-        tokens_per_expert = r.int()
-        tokens_per_expert = torch.minimum(
-            tokens_per_expert,
-            torch.ones((E,), device=r.device, dtype=torch.int) * T,
-        )
+        r = r.round().long()
+        tokens_per_expert = r
     elif gen_strategy == "max_t":
         tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
         tokens_per_expert.fill_(total_tokens / E)
@@ -281,40 +308,34 @@ def benchmark(


 def create_comparison_plot(
-    ratio, cuda_times, baseline_times, config_labels, strategy_name, id
+    ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
 ):
-    """Create a comparison plot for a specific generation strategy"""
-    fig, ax = plt.subplots(1, 1, figsize=(16, 6))
+    fig, ax = plt.subplots(1, 1, figsize=(18, 6))

     # Configure x-axis positions
     x = np.arange(len(config_labels))
-    width = 0.35
+    width = 0.25

     # Execution Time plot (lower is better)
+    ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
     ax.bar(
-        x - width / 2, cuda_times, width, label="CUDA Kernel", alpha=0.8, color="blue"
-    )
-    ax.bar(
-        x + width / 2,
-        baseline_times,
-        width,
-        label="Baseline",
-        alpha=0.8,
-        color="orange",
+        x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
     )

-    # Add speedup labels over each bar pair
+    # Add speedup labels over each bar trio
     for i in range(len(x)):
-        speedup = ratio[i]
-        max_height = max(cuda_times[i], baseline_times[i])
+        triton_v2_speedup = ratios[i][1]  # triton/v2
+        max_height = max(silu_v2_times[i], triton_times[i])
+
+        # Triton/V2 speedup
         ax.text(
-            x[i],
+            x[i] + width / 2,
            max_height + max_height * 0.02,
-            f"{speedup:.2f}x",
+            f"{triton_v2_speedup:.2f}x",
             ha="center",
             va="bottom",
             fontweight="bold",
-            fontsize=9,
+            fontsize=8,
         )

     ax.set_xlabel("Configuration")
@@ -332,56 +353,75 @@ def create_comparison_plot(


 def create_combined_plot(all_results):
-    """Create a combined plot with all strategies in one PNG"""
     num_strategies = len(all_results)
-    fig, axes = plt.subplots(num_strategies, 1, figsize=(20, 6 * num_strategies))
+    fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))

     if num_strategies == 1:
         axes = [axes]

     for idx, (
         strategy_name,
-        ratio,
-        cuda_times,
-        baseline_times,
+        all_ratios,
+        all_silu_v2_results,
+        all_triton_results,
         config_labels,
+        config_x_axis,
     ) in enumerate(all_results):
         ax = axes[idx]

+        # Flatten the nested results to get bandwidth percentages for plotting
+        silu_v2_bandwidths = []
+        triton_bandwidths = []
+        flat_ratios = []
+
+        for config_results in all_silu_v2_results:
+            for result in config_results:
+                silu_v2_bandwidths.append(result[3])  # bandwidth percentage
+
+        for config_results in all_triton_results:
+            for result in config_results:
+                triton_bandwidths.append(result[3])  # bandwidth percentage
+
+        for config_ratios in all_ratios:
+            for ratio in config_ratios:
+                flat_ratios.append(ratio)
+
         # Configure x-axis positions
         x = np.arange(len(config_labels))
-        width = 0.35
+        width = 0.25

-        # Execution Time plot (lower is better)
+        # Bandwidth utilization plot (higher is better)
         ax.bar(
-            x - width / 2,
-            cuda_times,
+            x,
+            silu_v2_bandwidths,
             width,
-            label="CUDA Kernel",
+            label="SiLU V2 (CUDA)",
             alpha=0.8,
             color="blue",
         )
         ax.bar(
-            x + width / 2,
-            baseline_times,
+            x + width,
+            triton_bandwidths,
             width,
-            label="Baseline",
+            label="Triton Kernel",
             alpha=0.8,
-            color="orange",
+            color="green",
         )

-        # Add speedup labels over each bar pair
+        # Add speedup labels over each bar trio
         for i in range(len(x)):
-            speedup = ratio[i]
-            max_height = max(cuda_times[i], baseline_times[i])
+            triton_v2_speedup = flat_ratios[i]  # triton/v2
+            max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])
+
+            # Triton/V2 speedup
             ax.text(
-                x[i],
+                x[i] + width / 2,
                 max_height + max_height * 0.02,
-                f"{speedup:.2f}x",
+                f"{triton_v2_speedup:.2f}x",
                 ha="center",
                 va="bottom",
                 fontweight="bold",
-                fontsize=9,
+                fontsize=8,
             )

         ax.set_xlabel("Configuration")
@@ -395,7 +435,7 @@ def create_combined_plot(all_results):
     ax.grid(True, alpha=0.3)

     plt.tight_layout()
-    filename = "../../silu_bench/silu_benchmark_combined.png"
+    filename = "silu_benchmark_combined_3way.png"
     plt.savefig(filename, dpi=300, bbox_inches="tight")
     plt.show()

@@ -405,7 +445,9 @@ def create_combined_plot(all_results):
 outer_dim = 7168
 configs = [
     # DeepSeekV3 Configs
+    # (1, 56, 7168),
     (8, 1024, 7168),
+    # (32, 56, 7168),
     # DeepSeekV3 Configs
     (32, 1024, 7168),
     # DeepSeekV3 Configs
@@ -417,6 +459,7 @@ num_warmups = 20

 strategy_descriptions = {
     "uniform": "Uniform Random",
+    "random_imbalanced": "Imbalanced Random",
     "max_t": "Even Assignment",
     "first_t": "experts[0] = T, experts[1:] = 0",
 }
@@ -433,28 +476,31 @@ for id, strategy in enumerate(strategies):
     print(f"Testing strategy: {strategy_descriptions[strategy]}")
     print(f"{'=' * 60}")

-    # Collect benchmark data for both algorithms
+    # Collect benchmark data for all three algorithms
     config_labels = []
     config_x_axis = []
-    all_cuda_results = []
-    all_baseline_results = []
+    all_silu_v2_results = []
+    all_triton_results = []
     all_ratios = []

     for E, T, H in configs:
-        total_tokens_config = [8 * E, 16 * E, 32 * E, 64 * E, 128 * E, 256 * E]
+        total_tokens_config = []
+        for i in [8, 16, 32, 64, 128, 256, 512]:
+            if i <= T:
+                total_tokens_config.append(i * E)
         config_x_axis.append(total_tokens_config)

-        cuda_results = []
-        baseline_results = []
+        silu_v2_results = []
+        triton_results = []
         ratios = []

         for total_tokens in total_tokens_config:
            config_label = f"E={E},T={T},H={H},TT={total_tokens}"
            config_labels.append(config_label)

-            # CUDA kernel results
-            time_ms_cuda, gflops, gbps, perc = benchmark(
-                silu_mul_fp8_quant_deep_gemm_cuda,
+            # SiLU V2 (CUDA kernel) results
+            time_ms_silu_v2, gflops, gbps, perc = benchmark(
+                persistent_masked_m_silu_mul_quant,
                 E,
                 T,
                 H,
@@ -463,9 +509,9 @@ for id, strategy in enumerate(strategies):
                 num_warmups=num_warmups,
                 gen_strategy=strategy,
             )
-            cuda_results.append((time_ms_cuda, gflops, gbps, perc))
+            silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))

-            # Baseline results
+            # Triton kernel results
             time_ms_triton, gflops, gbps, perc = benchmark(
                 silu_mul_fp8_quant_deep_gemm_triton,
                 E,
@@ -476,12 +522,20 @@ for id, strategy in enumerate(strategies):
                 num_warmups=num_warmups,
                 gen_strategy=strategy,
             )
-            baseline_results.append((time_ms_triton, gflops, gbps, perc))
-            ratios.append(time_ms_triton / time_ms_cuda)
+            triton_results.append((time_ms_triton, gflops, gbps, perc))

-            print(f"Completed: {config_label}")
-        all_cuda_results.append(cuda_results)
-        all_baseline_results.append(baseline_results)
+            # Calculate speedup ratios (triton baseline / implementation)
+            triton_v2_ratio = time_ms_triton / time_ms_silu_v2
+            ratios.append(triton_v2_ratio)
+
+            print(
+                f"Completed: {config_label}:"
+                f" V2: {time_ms_silu_v2:.3f}ms,"
+                f" Triton: {time_ms_triton:.3f}ms"
+            )
+
+        all_silu_v2_results.append(silu_v2_results)
+        all_triton_results.append(triton_results)
         all_ratios.append(ratios)

     # Store results for combined plotting
@@ -489,8 +543,8 @@ for id, strategy in enumerate(strategies):
         (
             strategy_descriptions[strategy],
             all_ratios,
-            all_cuda_results,
-            all_baseline_results,
+            all_silu_v2_results,
+            all_triton_results,
             config_labels,
             config_x_axis,
         )
@@ -498,15 +552,18 @@ for id, strategy in enumerate(strategies):

     # Print summary table for this strategy
     print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
-    print(f"{'Config':<20} {'CUDA Time(ms)':<12} {'Base Time(ms)':<12} {'Speedup':<8}")
-    print("-" * 60)
+    print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
+    print("-" * 90)

     for i, (E, T, H) in enumerate(configs):
-        speedup = baseline_results[i][0] / cuda_results[i][0]
+        # Get the first result for each config (simplifying for summary)
+        v2_time = silu_v2_results[i][0]
+        triton_time = triton_results[i][0]
+        triton_v2_speedup = triton_time / v2_time
         config_label = f"E={E:3d},T={T:4d},H={H:4d}"
         print(
-            f"{config_label:<20} {cuda_results[i][0]:8.5f} "
-            f"{baseline_results[i][0]:8.5f} {speedup:6.2f}x"
+            f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
+            f"{triton_v2_speedup:8.2f}x"
         )

@@ -514,15 +571,14 @@ def create_total_tokens_plot(all_results):
     num_strategies = len(all_results)
     num_configs = len(configs)

-    # Create side-by-side subplots: 2 columns for speedup and bandwidth percentage
     fig, axs = plt.subplots(
-        num_strategies, num_configs * 2, figsize=(28, 6 * num_strategies)
+        num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
     )

     # Add main title to the entire figure
     fig.suptitle(
-        "Performance Analysis: Speedup vs Bandwidth Utilization (Triton & CUDA)",
-        fontsize=16,
+        "Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
+        fontsize=18,
         fontweight="bold",
         y=0.98,
     )
@@ -539,8 +595,8 @@ def create_total_tokens_plot(all_results):
         (
             strategy_name,
             all_ratios,
-            all_cuda_results,
-            all_baseline_results,
+            all_silu_v2_results,
+            all_triton_results,
             config_labels,
             config_x_axis,
         ) = result
@@ -555,42 +611,54 @@ def create_total_tokens_plot(all_results):
             ratios = all_ratios[config_idx]
             total_tokens_values = config_x_axis[config_idx]

-            # Extract CUDA and Triton bandwidth percentages
-            cuda_bandwidth_percentages = [
-                result[3] for result in all_cuda_results[config_idx]
+            # Extract speedup ratios
+            triton_v2_ratios = [ratio for ratio in ratios]
+
+            # Extract bandwidth percentages for all implementations
+            v2_bandwidth_percentages = [
+                result[3] for result in all_silu_v2_results[config_idx]
             ]
             triton_bandwidth_percentages = [
-                result[3] for result in all_baseline_results[config_idx]
+                result[3] for result in all_triton_results[config_idx]
             ]

             # Plot speedup ratios vs total tokens (left plot)
             ax_speedup.plot(
-                total_tokens_values, ratios, "bo-", linewidth=3, markersize=8
+                total_tokens_values,
+                triton_v2_ratios,
+                "go-",
+                linewidth=3,
+                markersize=8,
+                label="Triton/V2 Speedup",
             )
             ax_speedup.set_title(
-                f"{strategy_name}\nSpeedup (CUDA/Triton)\nE={E}, T={T}, H={H}",
+                f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
                 fontsize=12,
                 fontweight="bold",
             )
             ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
             ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
+            ax_speedup.legend(prop={"weight": "bold"})
             ax_speedup.grid(True, alpha=0.3)

+            # Plot bandwidth utilization (right plot)
             ax_bandwidth.plot(
                 total_tokens_values,
-                cuda_bandwidth_percentages,
-                "ro-",
+                v2_bandwidth_percentages,
+                "o-",
                 linewidth=3,
                 markersize=8,
-                label="CUDA",
+                label="SiLU V2",
+                color="blue",
             )
             ax_bandwidth.plot(
                 total_tokens_values,
                 triton_bandwidth_percentages,
-                "go-",
+                "o-",
                 linewidth=3,
                 markersize=8,
                 label="Triton",
+                color="green",
             )
             ax_bandwidth.set_title(
                 f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
@@ -618,38 +686,12 @@ def create_total_tokens_plot(all_results):
             for label in ax.get_xticklabels() + ax.get_yticklabels():
                 label.set_fontweight("bold")

-            # Add value labels on speedup points
-            for x, y in zip(total_tokens_values, ratios):
+            # Add value labels on Triton/V2 speedup points
+            for x, y in zip(total_tokens_values, triton_v2_ratios):
                 ax_speedup.annotate(
                     f"{y:.2f}x",
                     (x, y),
                     textcoords="offset points",
-                    xytext=(0, 12),
-                    ha="center",
-                    fontsize=10,
-                    fontweight="bold",
-                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7),
-                )
-
-            # Add value labels on CUDA bandwidth points
-            for x, y in zip(total_tokens_values, cuda_bandwidth_percentages):
-                ax_bandwidth.annotate(
-                    f"{y:.1f}%",
-                    (x, y),
-                    textcoords="offset points",
-                    xytext=(0, 12),
-                    ha="center",
-                    fontsize=9,
-                    fontweight="bold",
-                    bbox=dict(boxstyle="round,pad=0.2", facecolor="red", alpha=0.3),
-                )
-
-            # Add value labels on Triton bandwidth points
-            for x, y in zip(total_tokens_values, triton_bandwidth_percentages):
-                ax_bandwidth.annotate(
-                    f"{y:.1f}%",
-                    (x, y),
-                    textcoords="offset points",
                     xytext=(0, -15),
                     ha="center",
                     fontsize=9,
@@ -659,17 +701,20 @@ def create_total_tokens_plot(all_results):

     plt.tight_layout()
     plt.subplots_adjust(top=0.93)  # Make room for main title
-    filename = "silu_benchmark_total_tokens.png"
+    filename = "silu_benchmark_total_tokens_3way.png"
     plt.savefig(filename, dpi=300, bbox_inches="tight")
     plt.show()

     return filename


-# Create combined plot with all strategies
-combined_plot_filename = create_total_tokens_plot(all_results)
+# Create comprehensive 3-way comparison plots
+combined_plot_filename = create_combined_plot(all_results)
+total_tokens_plot_filename = create_total_tokens_plot(all_results)

-print(f"\n{'=' * 60}")
-print("Benchmark Complete!")
-print(f"Generated combined plot: {combined_plot_filename}")
-print(f"{'=' * 60}")
+print(f"\n{'=' * 80}")
+print("3-Way Benchmark Suite Complete!")
+print(f"Generated combined comparison plot: {combined_plot_filename}")
+print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
+print("Compared: SiLU V2 (CUDA), and Triton implementations")
+print(f"{'=' * 80}")
@@ -14,7 +14,7 @@ import torch
 from tqdm import tqdm

 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    _w8a8_block_fp8_matmul,
+    _w8a8_triton_block_scaled_mm,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import triton
@@ -83,7 +83,7 @@ def w8a8_block_matmul(
     )

     if A.dtype == torch.float8_e4m3fn:
-        kernel = _w8a8_block_fp8_matmul
+        kernel = _w8a8_triton_block_scaled_mm
     else:
         raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# fmt: off
 # ruff: noqa: E501
 import time
@@ -9,7 +8,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_triton_block_scaled_mm,
 )
 from vllm.triton_utils import triton
 from vllm.utils.deep_gemm import (
@@ -20,19 +19,21 @@ from vllm.utils.deep_gemm import (
 )


-def benchmark_shape(m: int,
-                    n: int,
-                    k: int,
-                    warmup: int = 100,
-                    repeat: int = 10000,
-                    verbose: bool = False) -> dict:
+def benchmark_shape(
+    m: int,
+    n: int,
+    k: int,
+    warmup: int = 100,
+    repeat: int = 10000,
+    verbose: bool = False,
+) -> dict:
     """Benchmark all implementations for a specific (m, n, k) shape."""
     if verbose:
         print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")

     # Create test tensors
-    A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
-    B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
+    A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

     # Reference result in BF16
     torch.cuda.synchronize()
@@ -49,34 +50,39 @@ def benchmark_shape(m: int,
     # Pre-quantize A for all implementations
     A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
     A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
-    C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
+    C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
     A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
-        A, block_size[1], column_major_scales=True)
+        A, block_size[1], column_major_scales=True
+    )

     # === DeepGEMM Implementation ===
     def deepgemm_gemm():
-        fp8_gemm_nt((A_deepgemm, A_scale_deepgemm),
-                    (B_deepgemm, B_scale_deepgemm),
-                    C_deepgemm)
+        fp8_gemm_nt(
+            (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
+        )
         return C_deepgemm

     # === vLLM Triton Implementation ===
     def vllm_triton_gemm():
-        return w8a8_block_fp8_matmul(A_vllm,
-                                     B_vllm,
-                                     A_scale_vllm,
-                                     B_scale_vllm,
-                                     block_size,
-                                     output_dtype=torch.bfloat16)
+        return w8a8_triton_block_scaled_mm(
+            A_vllm,
+            B_vllm,
+            A_scale_vllm,
+            B_scale_vllm,
+            block_size,
+            output_dtype=torch.bfloat16,
+        )

     # === vLLM CUTLASS Implementation ===
     def vllm_cutlass_gemm():
-        return ops.cutlass_scaled_mm(A_vllm_cutlass,
-                                     B_vllm.T,
-                                     scale_a=A_scale_vllm_cutlass,
-                                     scale_b=B_scale_vllm.T,
-                                     out_dtype=torch.bfloat16)
+        return ops.cutlass_scaled_mm(
+            A_vllm_cutlass,
+            B_vllm.T,
+            scale_a=A_scale_vllm_cutlass,
+            scale_b=B_scale_vllm.T,
+            out_dtype=torch.bfloat16,
+        )

     # Run correctness check first
     if verbose:
@@ -93,26 +99,23 @@ def benchmark_shape(m: int,
         print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
         print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
         print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
-        print("vLLM Triton vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}")
-        print("vLLM CUTLASS vs DeepGEMM difference: "
-              f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}")
+        print(
+            "vLLM Triton vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
+        )
+        print(
+            "vLLM CUTLASS vs DeepGEMM difference: "
+            f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
+        )

     # Benchmark implementations
     implementations = {
         "DeepGEMM": deepgemm_gemm,
         "vLLM Triton": vllm_triton_gemm,
-        "vLLM CUTLASS": vllm_cutlass_gemm
+        "vLLM CUTLASS": vllm_cutlass_gemm,
     }

-    benchmark_results = {
-        "shape": {
-            "m": m,
-            "n": n,
-            "k": k
-        },
-        "implementations": {}
-    }
+    benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}

     for name, func in implementations.items():
         # Warmup
@@ -140,38 +143,36 @@ def benchmark_shape(m: int,
             "tflops": tflops,
             "gb_s": gb_s,
             "diff": {
-                "DeepGEMM":
-                0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm),
-                "Reference":
-                deepgemm_diff if name == "DeepGEMM" else
-                (vllm_triton_diff
-                 if name == "vLLM Triton" else vllm_cutlass_diff)
-            }
+                "DeepGEMM": 0.0
+                if name == "DeepGEMM"
+                else calc_diff(func(), C_deepgemm),
+                "Reference": deepgemm_diff
+                if name == "DeepGEMM"
+                else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
+            },
         }

         if verbose:
-            print(
-                f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s"
-            )
+            print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")

     # Calculate speedups
     baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
     for name, data in benchmark_results["implementations"].items():
         if name != "DeepGEMM":
             speedup = baseline / data["time_ms"]
-            benchmark_results["implementations"][name][
-                "speedup_vs_deepgemm"] = speedup
+            benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
             if verbose:
-                print(f"DeepGEMM is {1/speedup:.2f}x "
-                      f"{'faster' if 1/speedup > 1 else 'slower'} than {name}")
+                print(
+                    f"DeepGEMM is {1 / speedup:.2f}x "
+                    f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
+                )

-    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][
-        "time_ms"]
-    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][
-        "time_ms"]
+    vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
+    vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
     cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
-    benchmark_results["implementations"]["vLLM CUTLASS"][
-        "speedup_vs_triton"] = cutlass_vs_triton
+    benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
+        cutlass_vs_triton
+    )
     if verbose:
         print(
             f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
@@ -183,8 +184,7 @@ def benchmark_shape(m: int,

 def format_table_row(values, widths):
     """Format a row with specified column widths."""
-    return "| " + " | ".join(f"{val:{w}}"
-                             for val, w in zip(values, widths)) + " |"
+    return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"


 def print_table(headers, rows, title=None):
@@ -292,38 +292,50 @@ def run_benchmarks(verbose: bool = False):
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["DeepGEMM"]
-        deepgemm_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}"
-        ])
+        deepgemm_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+            ]
+        )

-    print_table(deepgemm_headers,
-                deepgemm_rows,
-                title="DeepGEMM Implementation:")
+    print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")

     # Print vLLM Triton table
-    triton_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"
-    ]
+    triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
     triton_rows = []
     for result in all_results:
         shape = result["shape"]
         impl_data = result["implementations"]["vLLM Triton"]
         speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
-        triton_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(speedup)
-        ])
+        triton_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(speedup),
+            ]
+        )

-    print_table(triton_headers,
-                triton_rows,
-                title="vLLM Triton Implementation:")
+    print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")

     # Print vLLM CUTLASS table
     cutlass_headers = [
-        "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM",
-        "vs Triton"
+        "m",
+        "n",
+        "k",
+        "Time (μs)",
+        "TFLOPS",
+        "GB/s",
+        "vs DeepGEMM",
+        "vs Triton",
     ]
     cutlass_rows = []
     for result in all_results:
@@ -331,28 +343,27 @@ def run_benchmarks(verbose: bool = False):
         impl_data = result["implementations"]["vLLM CUTLASS"]
         vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
         vs_triton = impl_data.get("speedup_vs_triton", 1.0)
-        cutlass_rows.append([
-            shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}",
-            f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}",
-            format_speedup(vs_deepgemm),
-            format_speedup(vs_triton)
-        ])
+        cutlass_rows.append(
+            [
+                shape["m"],
+                shape["n"],
+                shape["k"],
+                f"{impl_data['time_us']:.1f}",
+                f"{impl_data['tflops']:.1f}",
+                f"{impl_data['gb_s']:.1f}",
+                format_speedup(vs_deepgemm),
+                format_speedup(vs_triton),
+            ]
+        )

-    print_table(cutlass_headers,
-                cutlass_rows,
-                title="vLLM CUTLASS Implementation:")
+    print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")

     # Calculate and print averages
     print("\n===== AVERAGE PERFORMANCE =====")

     implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
     avg_metrics = {
-        impl: {
-            "tflops": 0,
-            "gb_s": 0,
-            "time_ms": 0
-        }
-        for impl in implementations
+        impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
     }

     for result in all_results:
@@ -370,9 +381,9 @@ def run_benchmarks(verbose: bool = False):
         avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
         avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
         avg_time = avg_metrics[impl]["time_ms"] / num_shapes
-        avg_rows.append([
-            impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"
-        ])
+        avg_rows.append(
+            [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
+        )

     print_table(avg_headers, avg_rows)

@@ -380,21 +391,19 @@ def run_benchmarks(verbose: bool = False):
     avg_speedups = {
         "DeepGEMM vs vLLM Triton": 0,
         "DeepGEMM vs vLLM CUTLASS": 0,
-        "vLLM CUTLASS vs vLLM Triton": 0
+        "vLLM CUTLASS vs vLLM Triton": 0,
     }

     for result in all_results:
         deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
         vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
-        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][
-            "time_ms"]
+        vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]

-        avg_speedups[
-            "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
-        avg_speedups[
-            "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
-        avg_speedups[
-            "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time
+        avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
+        avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
+        avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
+            vllm_triton_time / vllm_cutlass_time
+        )

     print("\n===== AVERAGE SPEEDUPS =====")
     speedup_headers = ["Comparison", "Speedup"]
@@ -412,8 +421,7 @@ def run_benchmarks(verbose: bool = False):

     for result in all_results:
         for impl in implementations:
-            avg_diff[impl] += result["implementations"][impl]["diff"][
-                "Reference"]
+            avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]

     diff_headers = ["Implementation", "Avg Diff vs Reference"]
     diff_rows = []
@@ -13,7 +13,7 @@ from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
 from statistics import mean
-from typing import NamedTuple, Optional, Union
+from typing import NamedTuple, Union

 import aiohttp  # type: ignore
 import numpy as np  # type: ignore

@@ -46,9 +46,9 @@ class ConversationSampling(str, Enum):

 class ClientArgs(NamedTuple):
     seed: int
-    max_num_requests: Optional[int]
+    max_num_requests: int | None
     skip_first_turn: bool
-    max_turns: Optional[int]
+    max_turns: int | None
     max_active_conversations: int
     verbose: bool
     print_content: bool

@@ -109,9 +109,9 @@ class RequestStats(NamedTuple):

 class MetricStats:
     def __init__(self) -> None:
-        self.min: Optional[float] = None
-        self.max: Optional[float] = None
-        self.avg: Optional[float] = None
+        self.min: float | None = None
+        self.max: float | None = None
+        self.avg: float | None = None
         self.sum = 0.0
         self.count = 0

@@ -143,7 +143,7 @@ class MovingAverage:
         self.index = 0
         self.sum = 0.0
         self.count = 0
-        self.avg: Optional[float] = None
+        self.avg: float | None = None

     def update(self, new_value: float) -> None:
         if self.count < self.window_size:

@@ -198,14 +198,6 @@ class DebugStats:
         self.logger.info("-" * 50)


-# Must support Python 3.8, we can't use str.removeprefix(prefix)
-# introduced in Python 3.9
-def remove_prefix(text: str, prefix: str) -> str:
-    if text.startswith(prefix):
-        return text[len(prefix) :]
-    return text
-
-
 def nanosec_to_millisec(value: float) -> float:
     return value / 1000000.0

@@ -220,8 +212,8 @@ async def send_request(
     chat_url: str,
     model: str,
     stream: bool = True,
-    min_tokens: Optional[int] = None,
-    max_tokens: Optional[int] = None,
+    min_tokens: int | None = None,
+    max_tokens: int | None = None,
 ) -> ServerResponse:
     payload = {
         "model": model,

@@ -250,9 +242,9 @@ async def send_request(
     timeout = aiohttp.ClientTimeout(total=timeout_sec)

     valid_response = True
-    ttft: Optional[float] = None
+    ttft: float | None = None
     chunk_delay: list[int] = []
-    latency: Optional[float] = None
+    latency: float | None = None
     first_chunk = ""
     generated_text = ""

@@ -269,7 +261,7 @@ async def send_request(
                 if not chunk_bytes:
                     continue

-                chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
                 if chunk == "[DONE]":
                     # End of stream
                     latency = time.perf_counter_ns() - start_time

@@ -364,7 +356,7 @@ async def send_turn(
     req_args: RequestArgs,
     verbose: bool,
     verify_output: bool,
-) -> Optional[RequestStats]:
+) -> RequestStats | None:
     assert messages_to_use > 0
     assert messages_to_use <= len(conversation_messages)

@@ -769,7 +761,7 @@ def get_client_config(
            "Number of conversations must be equal or larger than the number of clients"
        )

-    max_req_per_client: Optional[int] = None
+    max_req_per_client: int | None = None
    if args.max_num_requests is not None:
        # Max number of requests per client
        req_per_client = args.max_num_requests // args.num_clients

@@ -1032,7 +1024,7 @@ def process_statistics(
     warmup_percentages: list[float],
     test_params: dict,
     verbose: bool,
-    gen_conv_args: Optional[GenConvArgs] = None,
+    gen_conv_args: GenConvArgs | None = None,
     excel_output: bool = False,
 ) -> None:
     if len(client_metrics) == 0:
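The typing and string-handling changes above are behavior-preserving: `Optional[X]` and `X | None` describe the same type, and `str.removeprefix` (available since Python 3.9) replaces the local `remove_prefix` helper. A small sketch of the equivalence, with a made-up SSE chunk:

    chunk_bytes = b'data: {"choices": []}'

    # Old helper semantics, reproduced here only for comparison.
    def remove_prefix(text: str, prefix: str) -> str:
        if text.startswith(prefix):
            return text[len(prefix):]
        return text

    old = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
    new = chunk_bytes.decode("utf-8").removeprefix("data: ")
    assert old == new == '{"choices": []}'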
@@ -1,49 +0,0 @@
-# This local pyproject file is part of the migration from yapf to ruff format.
-# It uses the same core rules as the main pyproject.toml file, but with the
-# following differences:
-# - ruff line length is overridden to 88
-# - deprecated typing ignores (UP006, UP035) have been removed
-
-[tool.ruff]
-line-length = 88
-
-[tool.ruff.lint.per-file-ignores]
-"vllm/third_party/**" = ["ALL"]
-"vllm/version.py" = ["F401"]
-"vllm/_version.py" = ["ALL"]
-
-[tool.ruff.lint]
-select = [
-    # pycodestyle
-    "E",
-    # Pyflakes
-    "F",
-    # pyupgrade
-    "UP",
-    # flake8-bugbear
-    "B",
-    # flake8-simplify
-    "SIM",
-    # isort
-    "I",
-    # flake8-logging-format
-    "G",
-]
-ignore = [
-    # star imports
-    "F405", "F403",
-    # lambda expression assignment
-    "E731",
-    # Loop control variable not used within loop body
-    "B007",
-    # f-string format
-    "UP032",
-    # Can remove once 3.10+ is the minimum Python version
-    "UP007",
-]
-
-[tool.ruff.lint.isort]
-known-first-party = ["vllm"]
-
-[tool.ruff.format]
-docstring-code-format = true
@@ -213,6 +213,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
     endif()
     set(ONEDNN_AARCH64_USE_ACL "ON")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
+    add_compile_definitions(VLLM_USE_ACL)
   endif()

   set(ONEDNN_LIBRARY_TYPE "STATIC")

@@ -226,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
   set(ONEDNN_ENABLE_ITT_TASKS "OFF")
   set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
   set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
-  set(ONEDNN_VERBOSE "OFF")
+  set(ONEDNN_VERBOSE "ON")
   set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

   FetchContent_MakeAvailable(oneDNN)
@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
 else()
   FetchContent_Declare(
         flashmla
-        GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
+        GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
+        GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""

@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
 # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
 # Only build FlashMLA kernels if we are building for something compatible with
 # sm90a
-cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
-if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
+set(SUPPORT_ARCHS)
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
+  list(APPEND SUPPORT_ARCHS 9.0a)
+endif()
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
+  list(APPEND SUPPORT_ARCHS 10.0a)
+endif()
+
+cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
+if(FLASH_MLA_ARCHS)
+  set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
+  list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
+
   set(FlashMLA_SOURCES
-    ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
+    ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
+    ${flashmla_SOURCE_DIR}/csrc/pybind.cpp
+    ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
+    ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
+  )
+
+  set(FlashMLA_Extension_SOURCES
+    ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
+    ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
+    ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
+  )

   set(FlashMLA_INCLUDES
+    ${flashmla_SOURCE_DIR}/csrc
+    ${flashmla_SOURCE_DIR}/csrc/sm90
     ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-    ${flashmla_SOURCE_DIR}/csrc)
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+  )
+
+  set(FlashMLA_Extension_INCLUDES
+    ${flashmla_SOURCE_DIR}/csrc
+    ${flashmla_SOURCE_DIR}/csrc/sm90
+    ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/include
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+  )

   set_gencode_flags_for_srcs(
     SRCS "${FlashMLA_SOURCES}"
     CUDA_ARCHS "${FLASH_MLA_ARCHS}")

+  set_gencode_flags_for_srcs(
+    SRCS "${FlashMLA_Extension_SOURCES}"
+    CUDA_ARCHS "${FLASH_MLA_ARCHS}")
+
   define_gpu_extension_target(
     _flashmla_C
     DESTINATION vllm

@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
     USE_SABI 3
     WITH_SOABI)
+
+  # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+  # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+  target_compile_options(_flashmla_C PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
+    $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
+
+  define_gpu_extension_target(
+    _flashmla_extension_C
+    DESTINATION vllm
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES ${FlashMLA_Extension_SOURCES}
+    COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_GPU_ARCHES}
+    INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
+    USE_SABI 3
+    WITH_SOABI)
+
+  # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+  # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+  target_compile_options(_flashmla_extension_C PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
+    $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
 else()
-  # Create an empty target for setup.py when not targeting sm90a systems
+  # Create empty targets for setup.py when not targeting sm90a systems
   add_custom_target(_flashmla_C)
+  add_custom_target(_flashmla_extension_C)
 endif()
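The gating above decides which FlashMLA architectures are even candidates: 9.0a needs CUDA newer than 12.3, 10.0a needs CUDA newer than 12.8, and the candidates are then intersected with the requested CUDA_ARCHS. A hedged Python sketch of that selection (version thresholds and arch names come from the diff; the helper itself and the plain set intersection are simplifications of the CMake logic):

    def flashmla_support_archs(cuda_version: tuple[int, int], cuda_archs: set[str]) -> set[str]:
        """Sketch of the arch gating in the FlashMLA CMake snippet above (not the real code)."""
        support: set[str] = set()
        if cuda_version > (12, 3):
            support.add("9.0a")
        if cuda_version > (12, 8):
            support.add("10.0a")
        return support & cuda_archs

    print(flashmla_support_archs((12, 4), {"9.0a", "10.0a"}))  # {'9.0a'}
    print(flashmla_support_archs((12, 9), {"9.0a", "10.0a"}))  # both archs qualify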
cmake/external_projects/qutlass.cmake (new file, 97 lines)
@@ -0,0 +1,97 @@
+include(FetchContent)
+
+set(CUTLASS_INCLUDE_DIR "${CUTLASS_INCLUDE_DIR}" CACHE PATH "Path to CUTLASS include/ directory")
+
+if(DEFINED ENV{QUTLASS_SRC_DIR})
+  set(QUTLASS_SRC_DIR $ENV{QUTLASS_SRC_DIR})
+endif()
+
+if(QUTLASS_SRC_DIR)
+  FetchContent_Declare(
+    qutlass
+    SOURCE_DIR ${QUTLASS_SRC_DIR}
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+  )
+else()
+  FetchContent_Declare(
+    qutlass
+    GIT_REPOSITORY https://github.com/IST-DASLab/qutlass.git
+    GIT_TAG 830d2c4537c7396e14a02a46fbddd18b5d107c65
+    GIT_PROGRESS TRUE
+    CONFIGURE_COMMAND ""
+    BUILD_COMMAND ""
+  )
+  FetchContent_Populate(qutlass)
+  set(qutlass_SOURCE_DIR "${qutlass_SOURCE_DIR}")
+endif()
+
+if(NOT qutlass_SOURCE_DIR)
+  message(FATAL_ERROR "[QUTLASS] source directory could not be resolved.")
+endif()
+message(STATUS "[QUTLASS] QuTLASS is available at ${qutlass_SOURCE_DIR}")
+
+cuda_archs_loose_intersection(QUTLASS_ARCHS "12.0a;10.0a" "${CUDA_ARCHS}")
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND QUTLASS_ARCHS)
+
+  if(QUTLASS_ARCHS MATCHES "10\\.0a")
+    set(QUTLASS_TARGET_CC 100)
+  elseif(QUTLASS_ARCHS MATCHES "12\\.0a")
+    set(QUTLASS_TARGET_CC 120)
+  else()
+    message(FATAL_ERROR "[QUTLASS] internal error parsing CUDA_ARCHS='${QUTLASS_ARCHS}'.")
+  endif()
+
+  set(QUTLASS_SOURCES
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/bindings.cpp
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm.cu
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/gemm_ada.cu
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx.cu
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv.cu
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_mx_sm100.cu
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/fused_quantize_nv_sm100.cu
+  )
+
+  set(QUTLASS_INCLUDES
+    ${qutlass_SOURCE_DIR}
+    ${qutlass_SOURCE_DIR}/qutlass
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/include
+    ${qutlass_SOURCE_DIR}/qutlass/csrc/include/cutlass_extensions
+  )
+
+  if(CUTLASS_INCLUDE_DIR AND EXISTS "${CUTLASS_INCLUDE_DIR}/cutlass/cutlass.h")
+    list(APPEND QUTLASS_INCLUDES "${CUTLASS_INCLUDE_DIR}")
+  elseif(EXISTS "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include/cutlass/cutlass.h")
+    list(APPEND QUTLASS_INCLUDES "${qutlass_SOURCE_DIR}/qutlass/third_party/cutlass/include")
+    message(STATUS "[QUTLASS] Using QuTLASS vendored CUTLASS headers (no vLLM CUTLASS detected).")
+  else()
+    message(FATAL_ERROR "[QUTLASS] CUTLASS headers not found. "
+                        "Set -DCUTLASS_INCLUDE_DIR=/path/to/cutlass/include")
+  endif()
+
+  set_gencode_flags_for_srcs(
+    SRCS "${QUTLASS_SOURCES}"
+    CUDA_ARCHS "${QUTLASS_ARCHS}"
+  )
+
+  target_sources(_C PRIVATE ${QUTLASS_SOURCES})
+  target_include_directories(_C PRIVATE ${QUTLASS_INCLUDES})
+  target_compile_definitions(_C PRIVATE
+    QUTLASS_DISABLE_PYBIND=1
+    TARGET_CUDA_ARCH=${QUTLASS_TARGET_CC}
+  )
+
+  set_property(SOURCE ${QUTLASS_SOURCES} APPEND PROPERTY COMPILE_OPTIONS
+    $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr --use_fast_math -O3>
+  )
+
+else()
+  if("${CMAKE_CUDA_COMPILER_VERSION}" VERSION_LESS "12.8")
+    message(STATUS
+      "[QUTLASS] Skipping build: CUDA 12.8 or newer is required (found ${CMAKE_CUDA_COMPILER_VERSION}).")
+  else()
+    message(STATUS
+      "[QUTLASS] Skipping build: no supported arch (12.0a / 10.0a) found in "
+      "CUDA_ARCHS='${CUDA_ARCHS}'.")
+  endif()
+endif()
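The new qutlass.cmake maps the matched architecture to a single TARGET_CUDA_ARCH value (10.0a to 100, 12.0a to 120) and fails otherwise. A small sketch of that mapping, assuming the same branch precedence as the CMake above:

    def qutlass_target_cc(archs: list[str]) -> int:
        # Mirrors the if/elseif order in the CMake snippet: 10.0a is checked before 12.0a.
        if "10.0a" in archs:
            return 100
        if "12.0a" in archs:
            return 120
        raise ValueError(f"unsupported QUTLASS_ARCHS: {archs}")

    assert qutlass_target_cc(["12.0a"]) == 120
    assert qutlass_target_cc(["10.0a", "12.0a"]) == 100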
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
+        GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
@@ -16,7 +16,7 @@ import shutil

 from torch.utils.hipify.hipify_python import hipify

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()

     # Project directory where all the source + include files live.

@@ -34,15 +34,14 @@ if __name__ == '__main__':
     )

     # Source files to convert.
-    parser.add_argument("sources",
-                        help="Source files to hipify.",
-                        nargs="*",
-                        default=[])
+    parser.add_argument(
+        "sources", help="Source files to hipify.", nargs="*", default=[]
+    )

     args = parser.parse_args()

     # Limit include scope to project_dir only
-    includes = [os.path.join(args.project_dir, '*')]
+    includes = [os.path.join(args.project_dir, "*")]

     # Get absolute path for all source files.
     extra_files = [os.path.abspath(s) for s in args.sources]

@@ -51,25 +50,31 @@ if __name__ == '__main__':
     # The directory might already exist to hold object files so we ignore that.
     shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)

-    hipify_result = hipify(project_directory=args.project_dir,
-                           output_directory=args.output_dir,
-                           header_include_dirs=[],
-                           includes=includes,
-                           extra_files=extra_files,
-                           show_detailed=True,
-                           is_pytorch_extension=True,
-                           hipify_extra_files_only=True)
+    hipify_result = hipify(
+        project_directory=args.project_dir,
+        output_directory=args.output_dir,
+        header_include_dirs=[],
+        includes=includes,
+        extra_files=extra_files,
+        show_detailed=True,
+        is_pytorch_extension=True,
+        hipify_extra_files_only=True,
+    )

     hipified_sources = []
     for source in args.sources:
         s_abs = os.path.abspath(source)
-        hipified_s_abs = (hipify_result[s_abs].hipified_path if
-                          (s_abs in hipify_result
-                           and hipify_result[s_abs].hipified_path is not None)
-                          else s_abs)
+        hipified_s_abs = (
+            hipify_result[s_abs].hipified_path
+            if (
+                s_abs in hipify_result
+                and hipify_result[s_abs].hipified_path is not None
+            )
+            else s_abs
+        )
         hipified_sources.append(hipified_s_abs)

-    assert (len(hipified_sources) == len(args.sources))
+    assert len(hipified_sources) == len(args.sources)

     # Print hipified source files.
     print("\n".join(hipified_sources))
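The reformatted hipify wrapper keeps the same fallback rule: use the hipified path when the result has one for a source file, otherwise keep the original path. A sketch of that rule with a stand-in result mapping (the result-entry shape here is a simplified assumption, not the real hipify return type):

    from types import SimpleNamespace

    def pick_hipified(sources, hipify_result):
        """Return the hipified path when available, else the original source path."""
        picked = []
        for s_abs in sources:
            entry = hipify_result.get(s_abs)
            picked.append(
                entry.hipified_path
                if entry is not None and entry.hipified_path is not None
                else s_abs
            )
        return picked

    result = {"/src/a.cu": SimpleNamespace(hipified_path="/out/a.hip")}
    print(pick_hipified(["/src/a.cu", "/src/b.cu"], result))  # ['/out/a.hip', '/src/b.cu']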
@@ -310,13 +310,13 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
   list(REMOVE_DUPLICATES _PTX_ARCHS)
   list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)

-  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
-  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
+  # If x.0a or x.0f is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a or x.0f from SRC_CUDA_ARCHS and add x.0a or x.0f to _CUDA_ARCHS
   set(_CUDA_ARCHS)
   foreach(_arch ${_SRC_CUDA_ARCHS})
-    if(_arch MATCHES "\\a$")
+    if(_arch MATCHES "[af]$")
       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
-      string(REPLACE "a" "" _base "${_arch}")
+      string(REGEX REPLACE "[af]$" "" _base "${_arch}")
       if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
         list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
         list(APPEND _CUDA_ARCHS "${_arch}")
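The updated intersection treats the "a" and the new "f" arch suffixes the same way: when a suffixed arch (for example 9.0a or 10.0f) is requested via its plain base version, the base is swapped for the suffixed form. A hedged Python sketch of just the suffix handling (the full CMake function does more than this):

    def match_suffixed_archs(src_archs: list[str], tgt_archs: list[str]) -> list[str]:
        """If x.ya / x.yf is in src and plain x.y is in tgt, keep the suffixed form."""
        out = []
        tgt = list(tgt_archs)
        for arch in src_archs:
            if arch.endswith(("a", "f")):
                base = arch[:-1]
                if base in tgt:
                    tgt.remove(base)
                    out.append(arch)
        return out

    print(match_suffixed_archs(["9.0a", "10.0f"], ["9.0", "10.0"]))  # ['9.0a', '10.0f']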
@@ -28,10 +28,10 @@

 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
-  #include "../quantization/fp8/amd/quant_utils.cuh"
+  #include "../quantization/w8a8/fp8/amd/quant_utils.cuh"
 typedef __hip_bfloat16 __nv_bfloat16;
 #else
-  #include "../quantization/fp8/nvidia/quant_utils.cuh"
+  #include "../quantization/w8a8/fp8/nvidia/quant_utils.cuh"
 #endif

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
       auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
       if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_page_table(
         blk_coord,
         problem_shape,
         params.mainloop,
         shared_storage.tensors,
         pipeline_page_table, pipeline_pt_producer_state,
         local_split_kv
       );
     }
   }

@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
       auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
       if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_cpasync(
         blk_coord,

@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         params.mainloop_params,
         shared_storage.tensors,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         local_split_kv,
         /* must be shared pipe */
         pipeline_page_table, pipeline_pt_consumer_state
       );

@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
       auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
       if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_tma</* paged= */ true>(
         blk_coord,

@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         shared_storage.tensors,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         local_split_kv
       );
       cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
     }

@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     CUTLASS_PRAGMA_NO_UNROLL
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
       auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
       if (local_split_kv <= get<3>(blk_coord))
         continue;
       load_tma<false>(
         blk_coord,

@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         shared_storage.tensors,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         pipeline_load_qk, pipeline_load_qk_producer_state,
         local_split_kv
       );
       cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
     }

@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
       auto local_split_kv = params.split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
       if (local_split_kv <= get<3>(blk_coord))
         continue;
       mma(blk_coord,
         problem_shape,

@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         pipeline_mma_s, pipeline_mma_s_producer_state,
         pipeline_p_mma, pipeline_p_mma_consumer_state,
         pipeline_mma_o, pipeline_mma_o_producer_state,
         local_split_kv
       );
     }
   }

@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     for (; tile_scheduler.is_valid(); ++tile_scheduler) {
       auto blk_coord = tile_scheduler.get_block_coord();
       auto problem_shape = params.problem_shape;
       auto split_kv = params.split_kv;
       auto local_split_kv = split_kv;
       if (params.mainloop.ptr_seq != nullptr) {
         get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
         if (params.ptr_split_kv != nullptr) {
           local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
         }
       }
       if (local_split_kv <= get<3>(blk_coord))
         continue;
       compute(
         blk_coord,

@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
         pipeline_mma_s, pipeline_mma_s_consumer_state,
         pipeline_p_mma, pipeline_p_mma_producer_state,
         pipeline_mma_o, pipeline_mma_o_consumer_state,
         local_split_kv
       );
     }

@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
     cutlass::arch::NamedBarrier(
       (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
       kNamedBarrierEpilogue
-    ).arrive();
+    ).arrive_and_wait();

     return;
   }
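Each warp-specialized loop above re-reads the per-sequence values: when ptr_seq / ptr_split_kv are provided, the sequence length and split count are overridden per batch entry, and a tile whose split index is not below the local split count is skipped. A tiny Python sketch of that skip rule (names follow the kernel; the data is illustrative):

    def should_skip_tile(split_idx: int, default_split_kv: int,
                         per_seq_split_kv: list[int] | None, batch_idx: int) -> bool:
        """Mirror of `if (local_split_kv <= get<3>(blk_coord)) continue;`."""
        local_split_kv = default_split_kv
        if per_seq_split_kv is not None:
            local_split_kv = per_seq_split_kv[batch_idx]
        return local_split_kv <= split_idx

    # With a per-sequence split count of 2, split indices 0 and 1 run, 2 and 3 are skipped.
    assert [should_skip_tile(i, 4, [2, 4], 0) for i in range(4)] == [False, False, True, True]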
csrc/cache.h (16 lines changed)
@@ -56,3 +56,19 @@ void cp_gather_cache(
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
     int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+
+// Indexer K quantization and cache function
+void indexer_k_quant_and_cache(
+    torch::Tensor& k,             // [num_tokens, head_dim]
+    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
+    torch::Tensor& slot_mapping,  // [num_tokens]
+    int64_t quant_block_size,     // quantization block size
+    const std::string& scale_fmt);
+
+// Extract function to gather quantized K cache
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,     // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,              // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,          // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens); // [batch_size + 1]
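The dst_scale shape comment above (head_dim / quant_block_size * 4) is in bytes: each quantization block of the K vector gets one fp32 scale alongside one fp8 byte per value. A quick size check under an assumed configuration (head_dim 128, quant_block_size 128; these values are illustrative, not taken from the diff):

    head_dim = 128          # assumed indexer head dimension
    quant_block_size = 128  # one scale per 128 quantized values
    scale_bytes_per_token = head_dim // quant_block_size * 4  # fp32 scale = 4 bytes
    k_bytes_per_token = head_dim                              # fp8 value = 1 byte
    print(k_bytes_per_token, scale_bytes_per_token)           # 128 4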
@@ -9,15 +9,14 @@
 #include "quantization/vectorization_utils.cuh"

 #ifdef USE_ROCM
-  #include "quantization/fp8/amd/quant_utils.cuh"
+  #include "quantization/w8a8/fp8/amd/quant_utils.cuh"
 #else
-  #include "quantization/fp8/nvidia/quant_utils.cuh"
+  #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
 #endif

 #include <algorithm>
 #include <cassert>
-#include <map>
-#include <vector>
+#include <cfloat>

 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>

@@ -209,6 +208,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,

 namespace vllm {

+// Used to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -224,59 +237,51 @@ __global__ void reshape_and_cache_kernel(
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
-    // Padding token that should be ignored.
     return;
   }

   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
+  const int h_block_count = head_size / x;  // head_size//x

-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
-
-    const int64_t tgt_key_idx =
-        block_idx * num_heads * (head_size / x) * block_size * x +
-        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-        block_offset * x + x_offset;
-    const int64_t tgt_value_idx =
-        block_idx * num_heads * head_size * block_size +
-        head_idx * head_size * block_size + head_offset * block_size +
-        block_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-    }
+  const int h_block_idx = threadIdx.x;
+  if (h_block_idx >= num_heads * h_block_count) {
+    return;
+  }
+
+  const int head_idx = h_block_idx / h_block_count;
+  const int h_block = h_block_idx % h_block_count;
+
+  const scalar_t* __restrict__ key_src =
+      key + token_idx * key_stride + head_idx * head_size + h_block * x;
+  const int64_t src_value_start =
+      token_idx * value_stride + head_idx * head_size + h_block * x;
+
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * num_heads * h_block_count * block_size * x +
+      head_idx * h_block_count * block_size * x + h_block * block_size * x +
+      block_offset * x;
+  const int64_t tgt_value_start =
+      block_idx * num_heads * h_block_count * x * block_size +
+      head_idx * h_block_count * x * block_size + h_block * x * block_size +
+      block_offset;
+
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+
+  vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
+
+  const scalar_t* __restrict__ value_src = value + src_value_start;
+  cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
+#pragma unroll
+  for (int i = 0; i < x; i++) {
+    v_op(value_dst[i * block_size], value_src[i]);
   }
 }

-// Used by vectorization_utils to copy/convert one element
-template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-struct CopyWithScaleOp {
-  float scale;
-
-  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      dst = static_cast<OutT>(src);
-    } else {
-      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-    }
-  }
-};
-
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
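The rewritten kernel keeps the key-cache layout implied by its index arithmetic, [num_blocks, num_heads, head_size/x, block_size, x]; what changes is that one thread now writes one x-wide chunk (vectorized) instead of one element. A small sketch of the destination offset that key_dst points at, mirroring the arithmetic in the diff (the shape values are illustrative):

    def key_cache_offset(block_idx, head_idx, h_block, block_offset,
                         num_heads, head_size, block_size, x):
        """Start offset of one x-wide chunk in a [num_blocks, num_heads, head_size/x, block_size, x] cache."""
        h_block_count = head_size // x
        return (block_idx * num_heads * h_block_count * block_size * x
                + head_idx * h_block_count * block_size * x
                + h_block * block_size * x
                + block_offset * x)

    # Example: 16 heads, head_size 128, block_size 16, x = 8.
    print(key_cache_offset(block_idx=2, head_idx=3, h_block=1, block_offset=5,
                           num_heads=16, head_size=128, block_size=16, x=8))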
@@ -396,6 +401,241 @@ __global__ void concat_and_cache_mla_kernel(
   copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
 }

+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void concat_and_cache_ds_mla_kernel(
+    const scalar_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
+    const scalar_t* __restrict__ k_pe,  // [num_tokens, pe_dim]
+    cache_t* __restrict__ kv_cache,     // [num_blocks, block_size, (kv_lora_rank
+                                        //  + pe_dim)]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int block_stride,                    //
+    const int entry_stride,                    //
+    const int kv_c_stride,                     //
+    const int k_pe_stride,                     //
+    const int kv_lora_rank,                    //
+    const int pe_dim,                          //
+    const int block_size,                      //
+    const float* scale                         //
+) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+  const int64_t dst_idx_start =
+      block_idx * block_stride + block_offset * entry_stride;
+
+  // For the NoPE part, each tile of 128 elements is handled by half of one warp
+  // (16 threads). There are 4 total tiles, so 2 warps (64 threads).
+  // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
+  // The RoPE part (last 64 elements) is handled by another 1 warp (32 threads).
+  // So in total, we use 3 warps (96 threads) per block.
+
+  // Cast kv_cache to 16_bit for RoPE values
+  scalar_t* kv_cache_16bit =
+      reinterpret_cast<scalar_t*>(&kv_cache[dst_idx_start]);
+
+  // The last warp handles the RoPE part
+  if (threadIdx.x >= 64) {
+    // Each thread handles two elements of RoPE
+    const int8_t pe_idx_start = (threadIdx.x - 64) * 2;
+    const int64_t src_idx = token_idx * k_pe_stride + pe_idx_start;
+    // Vectorized load of two 16-bit values, performed as one 32-bit load
+    const int32_t vals = *reinterpret_cast<const int32_t*>(&k_pe[src_idx]);
+    // RoPE values start after the packed 8-bit NoPE values and the
+    // 32-bit scales
+    const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx_start;
+    // Vectorized store of two 16-bit values, performed as one 32-bit store
+    *reinterpret_cast<int32_t*>(&kv_cache_16bit[dst_idx]) = vals;
+    return;
+  }
+
+  // The first two warps handle the NoPE part
+  const int8_t warp_idx = threadIdx.x >> 5;
+  const int8_t lane_idx = threadIdx.x & 31;
+  const int8_t tile_idx = warp_idx * 2 + (lane_idx >> 4);
+
+  // Each thread handles 8 elements of NoPE
+  // Load the NoPE elements for this thread into registers
+  const int64_t src_idx_start = token_idx * kv_c_stride + (threadIdx.x * 8);
+  // Vectorized load of eight 16-bit values, performed as an int4 load
+  const int4 vals_i4 = *reinterpret_cast<const int4*>(&kv_c[src_idx_start]);
+  const scalar_t* vals = reinterpret_cast<const scalar_t*>(&vals_i4);
+
+  // Max absolute value of this thread's elements
+  float max_abs = fmaxf(fmaxf(fmaxf(fabsf(vals[0]), fabsf(vals[1])),
+                              fmaxf(fabsf(vals[2]), fabsf(vals[3]))),
+                        fmaxf(fmaxf(fabsf(vals[4]), fabsf(vals[5])),
+                              fmaxf(fabsf(vals[6]), fabsf(vals[7]))));
+
+  // Warp-level reduction to find the max absolute value in each half-warp
+#pragma unroll
+  for (int offset = 8; offset > 0; offset /= 2) {
+    max_abs = fmaxf(max_abs, VLLM_SHFL_XOR_SYNC_WIDTH(max_abs, offset, 16));
+  }
+
+  // Compute the scale for the tile
+  float tile_scale = max_abs / 448.f;
+  tile_scale = fmaxf(tile_scale, FLT_MIN);
+
+  // The first lane of each half-warp writes the scale to kv_cache
+  if ((lane_idx == 0) || (lane_idx == 16)) {
+    float* kv_cache_32bit = reinterpret_cast<float*>(&kv_cache[dst_idx_start]);
+    const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx;
+    kv_cache_32bit[dst_idx] = tile_scale;
+  }
+
+  // Now all threads in the block scale and write their elements
+  // NoPE data is packed in the first kv_lora_rank/2 bytes (first 256 bytes)
+  const int64_t dst_idx_base = dst_idx_start + (threadIdx.x * 8);
+
+  uint8_t result[8];
+#pragma unroll
+  for (int i = 0; i < 8; i++) {
+    result[i] =
+        fp8::scaled_convert<uint8_t, scalar_t, Fp8KVCacheDataType::kFp8E4M3>(
+            vals[i], tile_scale);
+  }
+
+  // Store as aligned 64-bit writes
+  *reinterpret_cast<uint64_t*>(&kv_cache[dst_idx_base]) =
+      *reinterpret_cast<const uint64_t*>(result);
+}
+
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void indexer_k_quant_and_cache_kernel(
+    const scalar_t* __restrict__ k,  // [num_tokens, head_dim]
+    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, cache_stride]
+    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
+    const int head_dim,                        // dimension of each head
+    const int quant_block_size,                // quantization block size
+    const int cache_block_size,                // cache block size
+    const int cache_stride,  // stride for each token in kv_cache
+    const bool use_ue8m0     // use ue8m0 scale format
+) {
+  constexpr int VEC_SIZE = 4;
+  const int64_t token_idx = blockIdx.x;
+  const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x +
+                                threadIdx.y * blockDim.x + threadIdx.x) *
+                               VEC_SIZE;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  const int64_t block_idx = slot_idx / cache_block_size;
+  const int64_t block_offset = slot_idx % cache_block_size;
+
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0 || (head_dim_idx >= head_dim)) {
+    return;
+  }
+
+  float2 k_val = (reinterpret_cast<const float2*>(
+      k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE];
+  scalar_t* k_val_ptr = reinterpret_cast<scalar_t*>(&k_val);
+  float amax = 0.0f;
+  for (int i = 0; i < VEC_SIZE; i++) {
+    amax = fmaxf(amax, fabsf(float(k_val_ptr[i])));
+  }
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+
+  // Reduced amax
+  for (int mask = 16; mask > 0; mask /= 2) {
+#ifdef USE_ROCM
+    amax = fmaxf(amax, __shfl_xor_sync(uint64_t(-1), amax, mask));
+#else
+    amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask));
+#endif
+  }
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+  float scale = fmaxf(amax, 1e-4) / 448.0f;
+  if (use_ue8m0) {
+    scale = exp2f(ceilf(log2f(scale)));
+  }
+
+  const int64_t dst_offset = block_idx * cache_block_size * cache_stride +
+                             block_offset * head_dim + head_dim_idx;
+  for (int i = 0; i < VEC_SIZE; i++) {
+    kv_cache[dst_offset + i] =
+        fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val_ptr[i], scale);
+  }
+  if (threadIdx.x == 0) {
+    const int64_t dst_scale_idx =
+        block_idx * cache_block_size * cache_stride +
+        cache_block_size * head_dim +
+        (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size;
+    reinterpret_cast<float*>(kv_cache)[dst_scale_idx / 4] = scale;
+  }
+}
+
+template <int BLOCK_Y_SIZE>
+__global__ void cp_gather_indexer_k_quant_cache_kernel(
+    const char* __restrict__ kv_cache,  // [num_blocks, block_size,
+                                        //  cache_stride]
+    char* __restrict__ dst_k,      // [num_tokens, head_dim]
+    char* __restrict__ dst_scale,  // [num_tokens, head_dim / quant_block_size *
+                                   //  4]
+    const int* __restrict__ block_table,  // [batch_size, num_blocks]
+    const int* __restrict__ cu_seq_lens,  // [batch_size + 1]
+    const int batch_size,                 // batch size
+    const int64_t token_stride,        // stride for each token in dst_k
+    const int64_t head_dim,            // dimension of each head
+    const int64_t block_stride,        // stride for each block in kv_cache
+    const int64_t cache_token_stride,  // stride for each token in kv_cache
+    const int64_t cache_block_size,    // num_tokens for each block in kv_cache
+    const int num_blocks,              // number of blocks
+    const int num_tokens,              // number of tokens
+    const int quant_block_size         // quantization block size
+) {
+  constexpr int VEC_SIZE = sizeof(float4) / sizeof(char);
+  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  const int head_idx = (blockIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE;
+  // Find batch index within a block
+  __shared__ int batch_idx[BLOCK_Y_SIZE];
+  for (int iter = 0; iter < cuda_utils::ceil_div(batch_size, int(blockDim.x));
+       iter++) {
+    int tid = iter * blockDim.x + threadIdx.x;
+    if (tid < batch_size) {
+      const int seq_start = cu_seq_lens[tid];
+      const int seq_end = cu_seq_lens[tid + 1];
+      if (token_idx >= seq_start && token_idx < seq_end) {
+        batch_idx[threadIdx.y] = tid;
+      }
+    }
+  }
+
+#ifndef USE_ROCM
+  __syncwarp();
+#endif
+
+  if (head_idx >= head_dim || token_idx >= num_tokens) {
+    return;
+  }
+  const int inbatch_seq_idx = token_idx - cu_seq_lens[batch_idx[threadIdx.y]];
+  const int block_idx = block_table[batch_idx[threadIdx.y] * num_blocks +
+                                    inbatch_seq_idx / cache_block_size];
+  const int64_t src_block_offset = block_idx * block_stride;
+  const int64_t cache_inblock_offset =
+      (inbatch_seq_idx % cache_block_size) * head_dim + head_idx;
+  const int64_t src_inblock_offset = src_block_offset + cache_inblock_offset;
+  const int64_t dst_inblock_offset = token_idx * token_stride + head_idx;
+
+  reinterpret_cast<float4*>(dst_k)[dst_inblock_offset / VEC_SIZE] =
+      reinterpret_cast<const float4*>(kv_cache)[src_inblock_offset / VEC_SIZE];
+  if (threadIdx.x == 0) {
+    const int64_t src_scale_offset =
+        src_block_offset + cache_block_size * head_dim +
+        cache_inblock_offset * 4 / quant_block_size;
+    reinterpret_cast<float*>(dst_scale)[dst_inblock_offset / quant_block_size] =
+        reinterpret_cast<const float*>(kv_cache)[src_scale_offset / 4];
+  }
+}
+
 }  // namespace vllm

 // KV_T is the data type of key and value tensors.
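The fp8_ds_mla cache entry written by concat_and_cache_ds_mla_kernel packs, per token: 512 fp8 NoPE values, 4 fp32 tile scales, and 64 16-bit RoPE values. That is where the 656-byte entry size checked later in concat_and_cache_mla comes from. A short arithmetic check (the element counts follow the kernel's comments and checks):

    kv_lora_rank = 512   # NoPE values, stored as 1-byte fp8 each
    num_tiles = 4        # one fp32 scale per 128-element tile
    pe_dim = 64          # RoPE values, stored as 2-byte 16-bit floats

    entry_bytes = kv_lora_rank * 1 + num_tiles * 4 + pe_dim * 2
    print(entry_bytes)   # 656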
@@ -431,14 +671,15 @@ void reshape_and_cache(

   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
+  int head_div_x = head_size / x;

   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_div_x, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
-                             CALL_RESHAPE_AND_CACHE)
+                             CALL_RESHAPE_AND_CACHE);
 }

 // KV_T is the data type of key and value tensors.
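With the vectorized kernel each thread handles one x-wide chunk rather than one element, so the launch now needs num_heads * (head_size / x) threads per token (still capped at 512) instead of num_heads * head_size. A quick comparison with illustrative shapes:

    num_heads, head_size, x = 16, 128, 8  # illustrative shapes; x is the cache's innermost chunk width

    threads_before = min(num_heads * head_size, 512)         # one thread per element
    threads_after = min(num_heads * (head_size // x), 512)   # one thread per x-wide chunk
    print(threads_before, threads_after)                     # 512 256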
@@ -509,6 +750,18 @@ void reshape_and_cache_flash(
           kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \
           reinterpret_cast<const float*>(scale.data_ptr()));

+// KV_T is the data type of key and value tensors.
+// CACHE_T is the stored data type of kv-cache.
+#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE)             \
+  vllm::concat_and_cache_ds_mla_kernel<KV_T, CACHE_T, KV_DTYPE>           \
+      <<<grid, block, 0, stream>>>(                                       \
+          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                       \
+          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                       \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                \
+          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,   \
+          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,     \
+          reinterpret_cast<const float*>(scale.data_ptr()));
+
 void concat_and_cache_mla(
     torch::Tensor& kv_c,  // [num_tokens, kv_lora_rank]
     torch::Tensor& k_pe,  // [num_tokens, pe_dim]
@@ -531,20 +784,43 @@ void concat_and_cache_mla(
   int pe_dim = k_pe.size(1);
   int block_size = kv_cache.size(1);

-  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+  if (kv_cache_dtype == "fp8_ds_mla") {
+    TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla");
+    TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla");
+    TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(),
+                "kv_cache.size(2) must be 656 bytes for fp8_ds_mla");
+    TORCH_CHECK(kv_c.itemsize() == 2,
+                "kv_c.itemsize() must be 2 for fp8_ds_mla");
+    TORCH_CHECK(k_pe.itemsize() == 2,
+                "k_pe.itemsize() must be 2 for fp8_ds_mla");
+  } else {
+    TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
+  }

   int kv_c_stride = kv_c.stride(0);
   int k_pe_stride = k_pe.stride(0);
   int block_stride = kv_cache.stride(0);
   int entry_stride = kv_cache.stride(1);

-  dim3 grid(num_tokens);
-  dim3 block(std::min(kv_lora_rank, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
-                             CALL_CONCAT_AND_CACHE_MLA);
+  if (kv_cache_dtype == "fp8_ds_mla") {
+    dim3 grid(num_tokens);
+    // For the NoPE part, each tile of 128 elements is handled by half of one
+    // warp (16 threads). There are 4 total tiles, so 2 warps (64 threads).
+    // Lanes 0 and 16 of each warp write the scale values for that warp's tiles.
+    // The RoPE part (last 64 elements) is handled by another 1 warp (32
+    // threads). So in total, we use 3 warps (96 threads) per block.
+    dim3 block(96);
+    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+                               CALL_CONCAT_AND_CACHE_DS_MLA);
+  } else {
+    dim3 grid(num_tokens);
+    dim3 block(std::min(kv_lora_rank, 512));
+    DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+                               CALL_CONCAT_AND_CACHE_MLA);
+  }
 }

 namespace vllm {
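For orientation, the 656-byte entry size asserted in the hunk above is consistent with the layout implied by the launch comments: 512 NoPE values stored as fp8 (512 bytes), one fp32 scale per 128-element tile (4 x 4 bytes), and 64 RoPE values kept at 2 bytes each (128 bytes). A hedged sketch of that arithmetic follows; the constant names are illustrative, not taken from the source.

// Hedged sketch: byte budget of one fp8_ds_mla cache entry, matching the
// TORCH_CHECKs above. Constant names are illustrative only.
constexpr int kNopeBytes  = 512 * 1;  // 512 fp8 NoPE values
constexpr int kScaleBytes = 4 * 4;    // one fp32 scale per 128-element tile
constexpr int kRopeBytes  = 64 * 2;   // 64 RoPE values, 2 bytes each
static_assert(kNopeBytes + kScaleBytes + kRopeBytes == 656, "656-byte entry");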
@@ -922,3 +1198,98 @@ void cp_gather_cache(
     TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
   }
 }
+
+// Macro to dispatch the kernel based on the data type.
+#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)            \
+  vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>          \
+      <<<grid, block, 0, stream>>>(                                        \
+          reinterpret_cast<KV_T*>(k.data_ptr()),                           \
+          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                 \
+          slot_mapping.data_ptr<int64_t>(), head_dim, quant_block_size,    \
+          cache_block_size, cache_stride, use_ue8m0);
+
+void indexer_k_quant_and_cache(
+    torch::Tensor& k,             // [num_tokens, head_dim]
+    torch::Tensor& kv_cache,      // [num_blocks, block_size, cache_stride]
+    torch::Tensor& slot_mapping,  // [num_tokens]
+    int64_t quant_block_size,     // quantization block size
+    const std::string& scale_fmt) {
+  int num_tokens = k.size(0);
+  int head_dim = k.size(1);
+  int cache_block_size = kv_cache.size(1);
+  int cache_stride = kv_cache.size(2);
+  bool use_ue8m0 = scale_fmt == "ue8m0";
+
+  TORCH_CHECK(k.device() == kv_cache.device(),
+              "k and kv_cache must be on the same device");
+  TORCH_CHECK(k.device() == slot_mapping.device(),
+              "k and slot_mapping must be on the same device");
+  TORCH_CHECK(head_dim % quant_block_size == 0,
+              "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 4;
+  dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) /
+                            (quant_block_size * vec_size));
+  dim3 block(32, vec_size);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
+                             CALL_INDEXER_K_QUANT_AND_CACHE);
+}
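To make the launch geometry above concrete, here is a hedged example with illustrative sizes (not taken from the diff): with head_dim = 128 and quant_block_size = 128, grid.y collapses to 1 and each 32 x 4 block covers one token.

// Hedged example of the grid/block arithmetic above; sizes are illustrative.
inline void example_indexer_launch_geometry() {
  int num_tokens = 100, head_dim = 128, quant_block_size = 128;
  constexpr int vec_size = 4;
  int grid_y = (head_dim + quant_block_size * vec_size - 1) /
               (quant_block_size * vec_size);  // = 1
  dim3 grid(num_tokens, grid_y);               // (100, 1)
  dim3 block(32, vec_size);                    // 128 threads per block
  (void)grid; (void)block;
}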
+
+// Macro to dispatch the kernel based on the data amount.
+#define CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(BLOCK_Y_SIZE)                   \
+  vllm::cp_gather_indexer_k_quant_cache_kernel<BLOCK_Y_SIZE>                 \
+      <<<dim3((num_tokens + BLOCK_Y_SIZE - 1) / BLOCK_Y_SIZE,                \
+              (head_dim + 8 * vec_size - 1) / (8 * vec_size)),               \
+         dim3(8, BLOCK_Y_SIZE), 0, stream>>>(                                \
+          reinterpret_cast<char*>(kv_cache.data_ptr()),                      \
+          reinterpret_cast<char*>(dst_k.data_ptr()),                         \
+          reinterpret_cast<char*>(dst_scale.data_ptr()),                     \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),  \
+          batch_size, dst_k.stride(0), dst_k.size(1), kv_cache.stride(0),    \
+          kv_cache.stride(1), kv_cache.size(1), block_table.size(1),         \
+          num_tokens, quant_block_size);
+
+void cp_gather_indexer_k_quant_cache(
+    const torch::Tensor& kv_cache,  // [num_blocks, block_size, cache_stride]
+    torch::Tensor& dst_k,           // [num_tokens, head_dim]
+    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
+    const torch::Tensor& block_table,  // [batch_size, num_blocks]
+    const torch::Tensor& cu_seq_lens   // [batch_size + 1]
+) {
+  int batch_size = block_table.size(0);
+  int num_tokens = dst_k.size(0);
+  int head_dim = dst_k.size(1);
+  int quant_block_size = head_dim * 4 / dst_scale.size(1);
+
+  TORCH_CHECK(kv_cache.device() == dst_k.device(),
+              "kv_cache and dst_k must be on the same device");
+  TORCH_CHECK(kv_cache.device() == dst_scale.device(),
+              "kv_cache and dst_scale must be on the same device");
+  TORCH_CHECK(kv_cache.device() == block_table.device(),
+              "kv_cache and block_table must be on the same device");
+  TORCH_CHECK(kv_cache.device() == cu_seq_lens.device(),
+              "kv_cache and cu_seq_lens must be on the same device");
+  TORCH_CHECK(head_dim % quant_block_size == 0,
+              "head_dim must be divisible by quant_block_size");
+
+  constexpr int vec_size = 16;
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_cache));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (num_tokens < 32) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(1);
+  } else if (num_tokens < 64) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(2);
+  } else if (num_tokens < 128) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(4);
+  } else if (num_tokens < 256) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(8);
+  } else if (num_tokens < 512) {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(16);
+  } else {
+    CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
+  }
+}
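The BLOCK_Y_SIZE ladder above trades blocks for per-block tokens as the batch grows; a hedged example with illustrative sizes (not taken from the diff) of the resulting launch shape:

// Hedged example of the BLOCK_Y_SIZE dispatch above; sizes are illustrative.
// num_tokens = 100 falls in the "< 128" branch, so BLOCK_Y_SIZE = 4.
inline void example_cp_gather_launch_geometry() {
  int num_tokens = 100, head_dim = 128;
  constexpr int vec_size = 16;
  constexpr int block_y = 4;
  dim3 grid((num_tokens + block_y - 1) / block_y,             // 25
            (head_dim + 8 * vec_size - 1) / (8 * vec_size));  // 1
  dim3 block(8, block_y);                                     // 32 threads
  (void)grid; (void)block;
}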
csrc/core/batch_invariant.hpp (new file, +16)
@@ -0,0 +1,16 @@
+#pragma once
+#include <cstdlib>
+#include <string>
+#include <cctype>
+
+namespace vllm {
+
+// vllm_kernel_override_batch_invariant(); returns true
+// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
+inline bool vllm_kernel_override_batch_invariant() {
+  std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
+  const char* val = std::getenv(env_key.c_str());
+  return (val && std::atoi(val) != 0) ? 1 : 0;
+}
+
+}  // namespace vllm
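Later hunks in this diff gate their vectorized launch paths on this helper; a minimal usage sketch, where everything except vllm_kernel_override_batch_invariant() and the environment variable is a placeholder:

// Minimal hedged sketch of a launcher consulting the override; the function
// name and both branch comments are placeholders, not taken from the diff.
#include "core/batch_invariant.hpp"

void launch_example(bool ptrs_are_aligned, int hidden_size) {
  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
  if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
    // vectorized path (width 8)
  } else {
    // scalar fallback path (width 0)
  }
}
// Enable at runtime with: VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1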
@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
 }

 void DNNLMatMulPrimitiveHandler::prepack_weight(
-    void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
-  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
-                                   {b_k_stride_, b_n_stride_});
+    void* original_b_ptr, dnnl::memory::desc original_b_md,
+    dnnl::memory::desc b_target_mem_desc) {
   dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
   dnnl::memory packed_weight(b_target_mem_desc, default_engine());
   {
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
   if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
     assert(!use_azp_);
   };
-  prepack_weight(args.b_ptr,
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
                      MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                    .use_bias = false,
@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
   assert(ab_type_ == dnnl::memory::data_type::f32 ||
          ab_type_ == dnnl::memory::data_type::bf16 ||
          ab_type_ == dnnl::memory::data_type::f16);
-  prepack_weight(args.b_ptr,
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
-                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
-                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
-                                   .use_bias = false,
-                                   .bias_type = dnnl::memory::data_type::undef},
+                     MSizeCacheKey{
+#ifdef VLLM_USE_ACL
+                         // Arm Compute Library (ACL) backend for oneDNN does
+                         // not support runtime
+                         // dimensions, so we set M to a default value
+                         .a_m_size = 128,
+                         .a_m_stride = b_k_size_,
+#else
+                         .a_m_size = DNNL_RUNTIME_DIM_VAL,
+                         .a_m_stride = DNNL_RUNTIME_DIM_VAL,
+#endif
+                         .use_bias = false,
+                         .bias_type = dnnl::memory::data_type::undef},
                      true)
                      .weights_desc());
   init_runtime_memory_cache(args);
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
   c_storage->set_data_handle((void*)args.c_ptr);
   c_mem_desc->dims[0] = args.a_m_size;

+#ifndef VLLM_USE_ACL
+  // The ACL backend of oneDNN does not support this; there, bias is handled
+  // by:
+  // 1. copying it into the result tensor
+  // 2. attaching a fused-sum post-op to the matmul primitive
   if (args.use_bias) {
     auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
     bias_storage->set_data_handle((void*)args.bias_ptr);
   }
+#endif
   dnnl::matmul matmul = get_matmul_cache(args);

+  // With ACL backend of oneDNN, the required memory format might change when the
+  // source tensor dims change. This does not really happen in practice, so isn't
+  // a performance hit, but we need to support it because the API allows for it.
+#ifdef VLLM_USE_ACL
+  auto new_expected_wei_desc =
+      dnnl::matmul::primitive_desc(
+          const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
+          .weights_desc();
+  if (new_expected_wei_desc != b_target_mem_desc_) {
+    prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
+                   b_target_mem_desc_, new_expected_wei_desc);
+  }
+#endif
+
   auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
   scratchpad_storage->set_data_handle(
       DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
   } else {
     a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                               {key.a_m_stride, 1});
+#ifdef VLLM_USE_ACL
+    // ACL's backend of oneDNN always expects the weight format to be "any"
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
+                              dnnl::memory::format_tag::any);
+#else
     b_md = b_target_mem_desc_;
+#endif
   }
   dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                           dnnl::memory::format_tag::ab);
@@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(

   if (key.use_bias) {
     dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+    // Since ACL's matmuls don't support passing a bias_md, we apply the bias
+    // through a fused-sum post-op
+#ifdef VLLM_USE_ACL
+    dnnl::post_ops post_ops;
+    post_ops.append_sum();
+    attr.set_post_ops(post_ops);
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+                                        attr);
+#else
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                         c_md, attr);
+#endif
   } else {
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                         attr);
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
                    default_engine(), nullptr);
   set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

+// ACL matmuls don't support bias_md, so we don't need these
+#ifndef VLLM_USE_ACL
   memory_cache_[DNNL_ARG_BIAS] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
+#endif
   memory_cache_[DNNL_ARG_SCRATCHPAD] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
 }
+
+bool is_onednn_acl_supported() {
+#ifdef VLLM_USE_ACL
+  return true;
+#else
+  return false;
+#endif
+}
@@ -101,7 +101,7 @@ class DNNLMatMulPrimitiveHandler {
  protected:
  DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);

-  void prepack_weight(void* original_b_ptr,
+  void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md,
                       dnnl::memory::desc b_target_mem_desc);

   void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
@@ -527,21 +527,42 @@ void onednn_mm(torch::Tensor& c,  // [M, OC], row-major
   MatMulPrimitiveHandler* ptr =
       reinterpret_cast<MatMulPrimitiveHandler*>(handler);

+  // ACL matmuls expect contiguous source tensors
+#ifdef VLLM_USE_ACL
+  torch::Tensor a_contig = a.contiguous();
+#endif
+
   MatMulPrimitiveHandler::ExecArgs exec_args;

+#ifdef VLLM_USE_ACL
+  exec_args.a_m_size = a_contig.size(0);
+  exec_args.a_m_stride = a_contig.stride(0);
+#else
   exec_args.a_m_size = a.size(0);
   exec_args.a_m_stride = a.stride(0);
+#endif
   VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
     if (bias.has_value()) {
       exec_args.use_bias = true;
       exec_args.bias_type = get_dnnl_type<scalar_t>();
+#ifdef VLLM_USE_ACL
+      // ACL matmuls in oneDNN do not support a bias.
+      // We handle a matmul with bias by doing: c = bias; c += matmul(a, b)
+      c.copy_(bias.value());
+#else
       exec_args.bias_ptr = bias->data_ptr<scalar_t>();
+#endif
     } else {
       exec_args.use_bias = false;
       exec_args.bias_type = get_dnnl_type<void>();
       exec_args.bias_ptr = nullptr;
     }
+#ifdef VLLM_USE_ACL
+    exec_args.a_ptr = a_contig.data_ptr<scalar_t>();
+#else
     exec_args.a_ptr = a.data_ptr<scalar_t>();
+#endif
     exec_args.c_ptr = c.data_ptr<scalar_t>();

     ptr->execute(exec_args);
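The ACL branch above works because a sum post-op makes the primitive compute C = A*B + C; pre-filling C with the broadcast bias therefore reproduces C = A*B + bias. A hedged sketch of that attribute setup using standard oneDNN calls, shown only for illustration and not copied from the diff:

// Hedged sketch: bias via a fused-sum post-op when bias_md is unavailable.
dnnl::post_ops post_ops;
post_ops.append_sum();           // dst := matmul(src, weights) + dst
dnnl::primitive_attr attr;
attr.set_post_ops(post_ops);
// Build the matmul primitive_desc with `attr` and no bias_md, then:
//   c.copy_(bias);              // dst starts as the broadcast bias
//   matmul.execute(...);        // dst becomes A*B + bias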
@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
 void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
                const std::optional<torch::Tensor>& bias, int64_t handler);

+bool is_onednn_acl_supported();
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -181,6 +183,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int handler) -> ()");
   ops.impl("onednn_mm", torch::kCPU, &onednn_mm);

+  // Check if oneDNN was built with ACL backend
+  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);
+
   // Create oneDNN W8A8 handler
   ops.def(
       "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
@@ -12,6 +12,7 @@ using CubMaxOp = cub::Max;
 #endif  // CUB_VERSION
 #else
 #include <hipcub/hipcub.hpp>
-using CubAddOp = cub::Sum;
-using CubMaxOp = cub::Max;
+namespace cub = hipcub;
+using CubAddOp = hipcub::Sum;
+using CubMaxOp = hipcub::Max;
 #endif  // USE_ROCM
@@ -27,7 +27,7 @@ VLLMDataTypeNames: dict[Union[VLLMDataType, DataType], str] = {
     **{
         VLLMDataType.u4b8: "u4b8",
         VLLMDataType.u8b128: "u8b128",
-    }
+    },
 }

 VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
@@ -35,7 +35,7 @@ VLLMDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
     **{
         VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
         VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
-    }
+    },
 }

 VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
@@ -43,7 +43,7 @@ VLLMDataTypeSize: dict[Union[VLLMDataType, DataType], int] = {
     **{
         VLLMDataType.u4b8: 4,
         VLLMDataType.u8b128: 8,
-    }
+    },
 }

 VLLMDataTypeVLLMScalarTypeTag: dict[Union[VLLMDataType, DataType], str] = {
@@ -67,15 +67,13 @@ VLLMDataTypeTorchDataTypeTag: dict[Union[VLLMDataType, DataType], str] = {
     DataType.f32: "at::ScalarType::Float",
 }

-VLLMKernelScheduleTag: dict[Union[
-    MixedInputKernelScheduleType, KernelScheduleType], str] = {
-        **KernelScheduleTag,  # type: ignore
-        **{
-            MixedInputKernelScheduleType.TmaWarpSpecialized:
-            "cutlass::gemm::KernelTmaWarpSpecialized",
-            MixedInputKernelScheduleType.TmaWarpSpecializedPingpong:
-            "cutlass::gemm::KernelTmaWarpSpecializedPingpong",
-            MixedInputKernelScheduleType.TmaWarpSpecializedCooperative:
-            "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
-        }
-    }
+VLLMKernelScheduleTag: dict[
+    Union[MixedInputKernelScheduleType, KernelScheduleType], str
+] = {
+    **KernelScheduleTag,  # type: ignore
+    **{
+        MixedInputKernelScheduleType.TmaWarpSpecialized: "cutlass::gemm::KernelTmaWarpSpecialized",  # noqa: E501
+        MixedInputKernelScheduleType.TmaWarpSpecializedPingpong: "cutlass::gemm::KernelTmaWarpSpecializedPingpong",  # noqa: E501
+        MixedInputKernelScheduleType.TmaWarpSpecializedCooperative: "cutlass::gemm::KernelTmaWarpSpecializedCooperative",  # noqa: E501
+    },
+}
|
@ -8,11 +8,37 @@
|
|||||||
#define VLLM_LAUNCH_BLOCKS_CAP 4
|
#define VLLM_LAUNCH_BLOCKS_CAP 4
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// compile-time estimate of max threads per SM for launch bounds.
|
// Compile-time estimate of max threads per SM for launch bounds.
|
||||||
|
// Families: 1024, 1536, 2048 threads/SM.
|
||||||
#ifndef VLLM_MAX_THREADS_PER_SM
|
#ifndef VLLM_MAX_THREADS_PER_SM
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
|
#ifdef __CUDA_ARCH__
|
||||||
#define VLLM_MAX_THREADS_PER_SM 1536
|
|
||||||
|
/* 1024 thr/SM: Turing (sm_75) */
|
||||||
|
#if (__CUDA_ARCH__ == 750)
|
||||||
|
#define VLLM_MAX_THREADS_PER_SM 1024
|
||||||
|
|
||||||
|
/* 1536 thr/SM: Ampere GA10x (sm_86/87), Ada (sm_89),
|
||||||
|
GB20x consumer (sm_120/121), Thor (sm_101 or sm_110) */
|
||||||
|
#elif (__CUDA_ARCH__ == 860) || (__CUDA_ARCH__ == 870) || \
|
||||||
|
(__CUDA_ARCH__ == 890) || (__CUDA_ARCH__ == 1010) || \
|
||||||
|
(__CUDA_ARCH__ == 1100) || (__CUDA_ARCH__ == 1200) || \
|
||||||
|
(__CUDA_ARCH__ == 1210)
|
||||||
|
#define VLLM_MAX_THREADS_PER_SM 1536
|
||||||
|
|
||||||
|
/* 2048 thr/SM: Volta (sm_70/72), Ampere GA100 (sm_80),
|
||||||
|
Hopper (sm_90), Blackwell (sm_100/103) */
|
||||||
|
#elif (__CUDA_ARCH__ == 700) || (__CUDA_ARCH__ == 720) || \
|
||||||
|
(__CUDA_ARCH__ == 800) || (__CUDA_ARCH__ == 900) || \
|
||||||
|
(__CUDA_ARCH__ == 1000) || (__CUDA_ARCH__ == 1030)
|
||||||
|
#define VLLM_MAX_THREADS_PER_SM 2048
|
||||||
|
|
||||||
|
/* Fallback: use 2048 for unknown future CCs */
|
||||||
|
#else
|
||||||
|
#define VLLM_MAX_THREADS_PER_SM 2048
|
||||||
|
#endif
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
/* Host pass (no __CUDA_ARCH__): neutral default */
|
||||||
#define VLLM_MAX_THREADS_PER_SM 2048
|
#define VLLM_MAX_THREADS_PER_SM 2048
|
||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
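A per-architecture threads-per-SM estimate like this is typically combined with a block-size cap to derive a min-blocks-per-SM hint for __launch_bounds__. A hedged sketch of that usual pattern follows; only VLLM_MAX_THREADS_PER_SM and VLLM_LAUNCH_BLOCKS_CAP come from the header above, the kernel and constant names are placeholders.

// Hedged sketch: deriving a __launch_bounds__ hint from the macros above.
// Assumes this header is included; example_kernel is a placeholder.
constexpr int kThreadsPerBlock = 256;
constexpr int kBlocksPerSM =
    (VLLM_MAX_THREADS_PER_SM / kThreadsPerBlock) < VLLM_LAUNCH_BLOCKS_CAP
        ? (VLLM_MAX_THREADS_PER_SM / kThreadsPerBlock)
        : VLLM_LAUNCH_BLOCKS_CAP;

__global__ void __launch_bounds__(kThreadsPerBlock, kBlocksPerSM)
    example_kernel(const float* in, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  out[i] = in[i];  // trivial body; only the launch-bounds wiring matters here
}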
@@ -1,6 +1,7 @@
 #include "type_convert.cuh"
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
+#include "core/batch_invariant.hpp"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -413,7 +414,9 @@ void fused_add_rms_norm(torch::Tensor& input,  // [..., hidden_size]
                          wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
+      !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
@@ -459,7 +462,8 @@ void poly_norm(torch::Tensor& out,  // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
   bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
     LAUNCH_FUSED_POLY_NORM(8);
   } else {
     LAUNCH_FUSED_POLY_NORM(0);
@@ -6,9 +6,10 @@
  */

 #include "type_convert.cuh"
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 #include "dispatch_utils.h"
 #include "cub_helpers.h"
+#include "core/batch_invariant.hpp"

 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
@@ -240,7 +241,9 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
+  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
+      !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
@@ -17,25 +17,30 @@ FILE_HEAD = """
 namespace MARLIN_NAMESPACE_NAME {
 """.strip()

-TEMPLATE = ("template __global__ void Marlin<"
-            "{{scalar_t}}, "
-            "{{w_type_id}}, "
-            "{{s_type_id}}, "
-            "{{threads}}, "
-            "{{thread_m_blocks}}, "
-            "{{thread_n_blocks}}, "
-            "{{thread_k_blocks}}, "
-            "{{'true' if m_block_size_8 else 'false'}}, "
-            "{{stages}}, "
-            "{{group_blocks}}, "
-            "{{'true' if is_zp_float else 'false'}}>"
-            "( MARLIN_KERNEL_PARAMS );")
+TEMPLATE = (
+    "template __global__ void Marlin<"
+    "{{scalar_t}}, "
+    "{{w_type_id}}, "
+    "{{s_type_id}}, "
+    "{{threads}}, "
+    "{{thread_m_blocks}}, "
+    "{{thread_n_blocks}}, "
+    "{{thread_k_blocks}}, "
+    "{{'true' if m_block_size_8 else 'false'}}, "
+    "{{stages}}, "
+    "{{group_blocks}}, "
+    "{{'true' if is_zp_float else 'false'}}>"
+    "( MARLIN_KERNEL_PARAMS );"
+)

 # int8 with zero point case (vllm::kU8) is also supported,
 # we don't add it to reduce wheel size.
 SCALAR_TYPES = [
-    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
-    "vllm::kFE2M1f"
+    "vllm::kU4",
+    "vllm::kU4B8",
+    "vllm::kU8B128",
+    "vllm::kFE4M3fn",
+    "vllm::kFE2M1f",
 ]
 THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]

@@ -58,11 +63,12 @@ def generate_new_kernels():
     all_template_str_list = []

     for group_blocks, m_blocks, thread_configs in itertools.product(
-            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
+        GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
+    ):
         # act order case only support gptq-int4 and gptq-int8
         if group_blocks == 0 and scalar_type not in [
-                "vllm::kU4B8", "vllm::kU8B128"
+            "vllm::kU4B8",
+            "vllm::kU8B128",
         ]:
             continue
         if thread_configs[2] == 256:
@@ -21,6 +21,7 @@
 #include <c10/cuda/CUDAGuard.h>
 #include "../cuda_compat.h"
 #include "../cub_helpers.h"
+#include "../core/batch_invariant.hpp"

 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -405,7 +406,8 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
   using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
   static constexpr int VPT = Constants::VPT;
   static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
-  const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
+  const bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  const int num_warps = batch_invariant_launch ? 32 : (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
   const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

   dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
@@ -100,6 +100,11 @@ void apply_repetition_penalties_(torch::Tensor& logits,
                                  const torch::Tensor& output_mask,
                                  const torch::Tensor& repetition_penalties);

+void top_k_per_row(const torch::Tensor& logits, const torch::Tensor& rowStarts,
+                   const torch::Tensor& rowEnds, torch::Tensor& indices,
+                   torch::Tensor& values, int64_t numRows, int64_t stride0,
+                   int64_t stride1);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
@@ -133,12 +138,12 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& out,
                               torch::Tensor& input,
                               torch::Tensor& input_global_scale);
 #endif
-void silu_mul_fp8_quant_deep_gemm_cuda(
+void persistent_masked_m_silu_mul_quant(
     const at::Tensor& input,   // (E, T, 2*H)
     const at::Tensor& counts,  // (E)
     at::Tensor& y_q,           // (E, T, H) [OUT]
     at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
-    int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens);
+    bool use_ue8m0);

 void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

|
@ -7,7 +7,7 @@
|
|||||||
#include "../cuda_compat.h"
|
#include "../cuda_compat.h"
|
||||||
#include "dispatch_utils.h"
|
#include "dispatch_utils.h"
|
||||||
|
|
||||||
#include "quantization/fp8/common.cuh"
|
#include "quantization/w8a8/fp8/common.cuh"
|
||||||
|
|
||||||
#include <c10/util/Float8_e4m3fn.h>
|
#include <c10/util/Float8_e4m3fn.h>
|
||||||
|
|
||||||
@@ -114,13 +114,22 @@ __global__ void act_and_mul_quant_kernel(
 }

 __device__ __forceinline__ float silu(float x) {
-  return (__fdividef(x, (1.f + expf(-x))));
+  return __fdividef(x, (1.f + expf(-x)));
 }

 __device__ __forceinline__ float2 silu2(float2 x) {
   return make_float2(silu(x.x), silu(x.y));
 }

+__device__ __forceinline__ __nv_bfloat162 silu2_v2(float2 x) {
+#ifndef USE_ROCM
+  return make_bfloat162(__float2bfloat16_rn(silu(x.x)),
+                        __float2bfloat16_rn(silu(x.y)));
+#else
+  return __float22bfloat162_rn(make_float2(silu(x.x), silu(x.y)));
+#endif
+}
+
 #ifndef USE_ROCM
 __device__ __forceinline__ float warp_max(float v) {
   static constexpr unsigned FULL_MASK = 0xffffffffu;
@@ -223,224 +232,308 @@ constexpr __nv_bfloat16 get_fp8_min() {
   return __nv_bfloat16(__nv_bfloat16_raw{.x = 50032});
 }
 }
-#ifndef USE_ROCM
-template <typename fp8_type, int32_t NUM_WARPS, typename Idx_t,
-          int NUM_PARALLEL_TOKENS, bool USE_UE8M0, int GROUP_SIZE = 128,
+
+template <typename Idx_t>
+__device__ __forceinline__ int warp_expert_search(
+    int idx, int n, const Idx_t* __restrict__ input, Idx_t val) {
+  const Idx_t* input_ptr = input + idx;
+  int base_offset = 0;
+
+  for (;;) {
+    bool move_on = (idx < n && *input_ptr <= val);
+
+    unsigned mask = __ballot_sync(0xffffffff, move_on);
+
+    if (mask != 0xffffffffu) {
+      int last_lane = 31 - __clz(mask);
+      return base_offset + last_lane;
+    }
+
+    input_ptr += 32;
+    base_offset += 32;
+    idx += 32;
+  }
+}
+
+template <int num_parallel_tokens>
+__device__ __forceinline__ void token_bounds(int32_t n_tokens,
+                                             int32_t worker_id,
+                                             int32_t& n_tokens_lower,
+                                             int32_t& n_tokens_upper) {
+  if (n_tokens < num_parallel_tokens && worker_id < n_tokens) {
+    if (worker_id >= num_parallel_tokens) return;
+    n_tokens_lower = worker_id;
+    n_tokens_upper = worker_id + 1;
+  } else {
+    int32_t chunk_size = n_tokens / num_parallel_tokens;
+    int32_t residual = n_tokens - chunk_size * num_parallel_tokens;
+    auto calc_id = [&](int32_t id) {
+      if (id < residual)
+        return min(n_tokens, id * (chunk_size + 1));
+      else
+        return min(n_tokens, id * chunk_size + residual);
+    };
+    n_tokens_lower = calc_id(worker_id);
+    n_tokens_upper = calc_id(worker_id + 1);
+  }
+}
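As a quick sanity check of the general chunking branch in token_bounds above (the small-n_tokens special case is omitted), here is a hedged worked example with illustrative numbers, mirrored as a host-side helper for clarity:

// Hedged host-side mirror of the device helper above, for illustration only.
#include <algorithm>
void token_bounds_host(int n_tokens, int workers, int id, int& lo, int& hi) {
  int chunk = n_tokens / workers, residual = n_tokens - chunk * workers;
  auto calc = [&](int i) {
    return std::min(n_tokens,
                    i < residual ? i * (chunk + 1) : i * chunk + residual);
  };
  lo = calc(id);      // e.g. n_tokens=10, workers=4: ids 0..3 -> 0, 3, 6, 8
  hi = calc(id + 1);  // upper bounds                          -> 3, 6, 8, 10
}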
|
|
||||||
|
template <int BLOCK_COUNT, int SMEM_SIZE_BYTES_Y, typename fp8_type,
|
||||||
|
int THREADS, typename Idx_t, bool USE_UE8M0, int GROUP_SIZE = 128,
|
||||||
int NUM_STAGES = 3>
|
int NUM_STAGES = 3>
|
||||||
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
__global__ void silu_mul_fp8_quant_deep_gemm_kernel(
|
||||||
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
|
const __nv_bfloat16* __restrict__ _input, fp8_type* __restrict__ _y_q,
|
||||||
float* __restrict__ _y_s, const int32_t* __restrict__ counts,
|
float* __restrict__ _y_s, const int32_t* __restrict__ tokens_per_expert,
|
||||||
|
|
||||||
// sizes
|
// sizes
|
||||||
int H, int G,
|
Idx_t E, Idx_t T, Idx_t H,
|
||||||
|
|
||||||
// strides (in elements)
|
// strides (in elements)
|
||||||
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
|
Idx_t stride_i_e, Idx_t stride_i_t, Idx_t stride_i_h, Idx_t stride_yq_e,
|
||||||
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
|
Idx_t stride_yq_t, Idx_t stride_yq_h, Idx_t stride_ys_e, Idx_t stride_ys_t,
|
||||||
Idx_t stride_ys_g, Idx_t stride_counts_e) {
|
Idx_t stride_ys_g, Idx_t stride_counts_e) {
|
||||||
|
#ifndef USE_ROCM
|
||||||
|
static constexpr int NUM_WARPS = THREADS / WARP_SIZE;
|
||||||
|
|
||||||
|
static constexpr int LOAD_STAGE_SIZE = 2 * GROUP_SIZE / 8;
|
||||||
|
static constexpr int LOAD_STAGE_MOD = NUM_STAGES * LOAD_STAGE_SIZE;
|
||||||
|
|
||||||
|
static constexpr int COMPUTE_STAGE_SIZE = 2 * GROUP_SIZE / 4;
|
||||||
|
static constexpr int COMPUTE_STAGE_MOD = COMPUTE_STAGE_SIZE * NUM_STAGES;
|
||||||
|
|
||||||
|
extern __shared__ __align__(16) __int128_t smem_128[];
|
||||||
|
|
||||||
|
int* s_expert_offsets =
|
||||||
|
reinterpret_cast<int*>(smem_128 + (SMEM_SIZE_BYTES_Y / 16));
|
||||||
|
|
||||||
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
|
static constexpr __nv_bfloat16 fp8_min = get_fp8_min<fp8_type>();
|
||||||
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
|
static constexpr __nv_bfloat16 fp8_max = get_fp8_max<fp8_type>();
|
||||||
// We assign EPS with its 16-bit unsigned counterpart to allow constexpr.
|
// We assign EPS with it's 16-bit unsigned counterpart to allow constexpr.
|
||||||
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
|
static constexpr __nv_bfloat16 EPS = (__nv_bfloat16_raw{.x = 11996});
|
||||||
|
int tid = threadIdx.x;
|
||||||
|
int warp_id = tid >> 5;
|
||||||
|
int lane_id = tid & 0x1f;
|
||||||
|
|
||||||
// We pack 8 16-bit bfloat16 values into a 128-bit __int128_t.
|
int running_sum{};
|
||||||
static constexpr int32_t BFLOAT16_PER_GROUP = 8;
|
if (!warp_id) {
|
||||||
|
for (int i = 0; i < E; i += WARP_SIZE) {
|
||||||
|
bool valid = (i + threadIdx.x) < E;
|
||||||
|
int value =
|
||||||
|
(valid ? tokens_per_expert[i + threadIdx.x * stride_counts_e] : 0) +
|
||||||
|
(!lane_id ? running_sum : 0);
|
||||||
|
|
||||||
// We split the shared memory in half, corresponding to gate and up matrices:
|
for (int offset = 1; offset < 32; offset *= 2) {
|
||||||
// [...gate_i, ...up_i] where 0 <= i < stages.
|
int n = __shfl_up_sync(0xFFFFFFFFu, value, offset);
|
||||||
static constexpr int32_t S_NUM_128 =
|
if (lane_id >= offset) value += n;
|
||||||
2u * (GROUP_SIZE / BFLOAT16_PER_GROUP) * NUM_WARPS * NUM_STAGES;
|
}
|
||||||
static constexpr auto THREAD_COUNT = NUM_WARPS * WARP_SIZE;
|
|
||||||
static constexpr int HALF_THREAD_COUNT = THREAD_COUNT / 2;
|
|
||||||
static constexpr int32_t S_NUM_64 = S_NUM_128 * 2;
|
|
||||||
__shared__ __int128_t __align__(16) s_buff_128[S_NUM_128];
|
|
||||||
|
|
||||||
const int32_t tid = threadIdx.x;
|
if (valid) {
|
||||||
const int32_t warp_id = tid / WARP_SIZE;
|
s_expert_offsets[i + threadIdx.x + 1] = value;
|
||||||
const int32_t lane_id = tid % WARP_SIZE;
|
}
|
||||||
|
|
||||||
auto s_buff_compute_32 = reinterpret_cast<__nv_bfloat162*>(s_buff_128);
|
running_sum = __shfl_sync(0xFFFFFFFFu, value, WARP_SIZE - 1);
|
||||||
|
}
|
||||||
|
|
||||||
// block handles one (expert e, group g)
|
if (!lane_id) {
|
||||||
int32_t pid = blockIdx.x;
|
s_expert_offsets[0] = 0;
|
||||||
int32_t e = pid / G;
|
}
|
||||||
int32_t g = pid % G;
|
|
||||||
|
|
||||||
const int32_t n_tokens = counts[e * stride_counts_e];
|
|
||||||
|
|
||||||
if (!n_tokens) {
|
|
||||||
return; // Exit ASAP.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const Idx_t stride_i_t_128 = stride_i_t / 8u;
|
__syncthreads();
|
||||||
|
|
||||||
int32_t n_tokens_lower, n_tokens_upper;
|
int32_t total_tokens = s_expert_offsets[E];
|
||||||
|
|
||||||
|
const int warp_position_yq = warp_id * (H / NUM_WARPS);
|
||||||
|
const int warp_position_scales = warp_id * (H / (GROUP_SIZE * NUM_WARPS));
|
||||||
|
|
||||||
|
// A single block will handle tokens_per_block tokens.
|
||||||
// Each block i iterates over tokens of a slice of n_tokens =
|
// Each block i iterates over tokens of a slice of n_tokens =
|
||||||
// expert_counts[i], with the size of chunk being
|
// expert_counts[i], with the size of chunk being
|
||||||
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
|
// (n_tokens / NUM_PARALLEL_TOKENS) + residual, instead of
|
||||||
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
|
// updiv(n_tokens, NUM_PARALLEL_TOKENS) for better scheduling.
|
||||||
if (n_tokens < NUM_PARALLEL_TOKENS && blockIdx.y < n_tokens) {
|
|
||||||
// Specialize this, but can be likely fused.
|
|
||||||
if (blockIdx.y >= NUM_PARALLEL_TOKENS) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
n_tokens_lower = blockIdx.y;
|
|
||||||
n_tokens_upper = blockIdx.y + 1;
|
|
||||||
} else {
|
|
||||||
auto chunk_size = n_tokens / NUM_PARALLEL_TOKENS;
|
|
||||||
auto residual = n_tokens - chunk_size * NUM_PARALLEL_TOKENS;
|
|
||||||
auto calc_id = [&](int32_t id) {
|
|
||||||
if (id < residual) {
|
|
||||||
return min(n_tokens, id * (chunk_size + 1));
|
|
||||||
} else {
|
|
||||||
return min(n_tokens, id * chunk_size + residual);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
n_tokens_lower = calc_id(blockIdx.y);
|
|
||||||
n_tokens_upper = calc_id(blockIdx.y + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (n_tokens_lower >= n_tokens_upper) {
|
// Each warp will get space to store its hidden dim for gate and up.
|
||||||
|
__int128_t* s_hidden_load = smem_128 + warp_id * ((2 * 128 / 8) * NUM_STAGES);
|
||||||
|
__int128_t* smem_load_ptr = s_hidden_load + lane_id;
|
||||||
|
|
||||||
|
const __nv_bfloat16 fp8_inv = __hdiv(__float2bfloat16(1.f), fp8_max);
|
||||||
|
|
||||||
|
int32_t compute_pipeline_offset_64 = 0;
|
||||||
|
int32_t load_stage_offset{};
|
||||||
|
const __nv_bfloat16 one_bf16 = __float2bfloat16_rn(1.f);
|
||||||
|
|
||||||
|
__int64_t* smem_compute_ptr = reinterpret_cast<__int64_t*>(smem_128) +
|
||||||
|
warp_id * (2 * (GROUP_SIZE / 4) * NUM_STAGES) +
|
||||||
|
lane_id;
|
||||||
|
__int64_t* s_gate64_ptr = smem_compute_ptr;
|
||||||
|
__int64_t* s_up64_ptr = smem_compute_ptr + GROUP_SIZE / 4;
|
||||||
|
|
||||||
|
int tokens_lower, tokens_upper;
|
||||||
|
|
||||||
|
token_bounds<BLOCK_COUNT>(total_tokens, blockIdx.x, tokens_lower,
|
||||||
|
tokens_upper);
|
||||||
|
|
||||||
|
Idx_t expert_id{}, expert_offset{}, next_expert_offset{};
|
||||||
|
int token_id = tokens_lower;
|
||||||
|
int32_t t_load{};
|
||||||
|
|
||||||
|
if (token_id < tokens_upper) {
|
||||||
|
expert_id = warp_expert_search<int>(lane_id, E, s_expert_offsets, token_id);
|
||||||
|
expert_offset = s_expert_offsets[expert_id];
|
||||||
|
next_expert_offset = s_expert_offsets[expert_id + 1];
|
||||||
|
} else {
|
||||||
|
// This thread block has no work to do.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We do calculations here, using constexpr wherever possible.
|
int t_load_bound = H / (GROUP_SIZE * NUM_WARPS);
|
||||||
const Idx_t base_i = e * stride_i_e + NUM_WARPS * g * GROUP_SIZE * stride_i_h;
|
|
||||||
const Idx_t base_ys = e * stride_ys_e + NUM_WARPS * g * stride_ys_g;
|
|
||||||
const Idx_t base_yq =
|
|
||||||
e * stride_yq_e + NUM_WARPS * g * GROUP_SIZE * stride_yq_h;
|
|
||||||
Idx_t gate_off_128 = (base_i / static_cast<Idx_t>(8u));
|
|
||||||
auto input_128_ptr = reinterpret_cast<const __int128_t*>(_input);
|
|
||||||
auto gate_128_ptr = input_128_ptr + gate_off_128 + (tid % HALF_THREAD_COUNT) +
|
|
||||||
stride_i_t_128 * n_tokens_lower;
|
|
||||||
auto up_128_ptr = gate_128_ptr + (H * stride_i_h) / 8u;
|
|
||||||
auto y_s_ptr =
|
|
||||||
_y_s + base_ys + warp_id * stride_ys_g + n_tokens_lower * stride_ys_t;
|
|
||||||
auto y_q_ptr = _y_q + base_yq + warp_id * GROUP_SIZE +
|
|
||||||
stride_yq_t * n_tokens_lower + 4 * lane_id;
|
|
||||||
int32_t t_load = n_tokens_lower, load_stage_id = 0;
|
|
||||||
auto s_buff_gate_load_128 = s_buff_128 + (tid % HALF_THREAD_COUNT);
|
|
||||||
auto s_buff_up_load_128 = s_buff_gate_load_128 + S_NUM_128 / 2u;
|
|
||||||
int32_t stage_offset{};
|
|
||||||
|
|
||||||
static constexpr int32_t LOAD_STAGE_SIZE = (NUM_WARPS * WARP_SIZE / 2);
|
Idx_t base_i = ((expert_id * stride_i_e) / 8) +
|
||||||
static constexpr int32_t LOAD_STAGE_MOD =
|
(token_id - expert_offset) * stride_i_t / 8;
|
||||||
NUM_STAGES * (NUM_WARPS * WARP_SIZE / 2);
|
const Idx_t gate_warp_offset =
|
||||||
|
warp_id * ((stride_i_h * H) / (8 * NUM_WARPS)) + (lane_id & 0b1111);
|
||||||
|
|
||||||
|
const __int128_t* input_128_ptr =
|
||||||
|
reinterpret_cast<const __int128_t*>(_input) + gate_warp_offset +
|
||||||
|
((lane_id < 16) ? 0 : ((H * stride_i_h) / 8));
|
||||||
|
__int128_t* load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
|
||||||
|
|
||||||
|
auto token_offset = token_id - expert_offset;
|
||||||
|
|
||||||
// Two halves of all threads in a block conduct global loads for gate and up,
|
|
||||||
// repsectively.
|
|
||||||
auto load_and_advance_y_pred = [&] {
|
auto load_and_advance_y_pred = [&] {
|
||||||
if (t_load < n_tokens_upper) {
|
if (t_load < t_load_bound) {
|
||||||
auto s_gate_stage_128_staged_ptr = s_buff_gate_load_128 + stage_offset;
|
// Here we are simply continuing to load data
|
||||||
auto s_up_stage_128_staged_ptr = s_buff_up_load_128 + stage_offset;
|
// from the current token.
|
||||||
|
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
|
||||||
|
|
||||||
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
|
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
|
||||||
// unnecessary ALU ops.
|
// unnecessary ALU ops.
|
||||||
stage_offset += LOAD_STAGE_SIZE;
|
load_stage_offset += LOAD_STAGE_SIZE;
|
||||||
stage_offset %= LOAD_STAGE_MOD;
|
load_stage_offset %= LOAD_STAGE_MOD;
|
||||||
|
|
||||||
if (tid < HALF_THREAD_COUNT) {
|
cp_async4(smem_load_ptr_staged, load_ptr);
|
||||||
cp_async4(s_gate_stage_128_staged_ptr, gate_128_ptr);
|
load_ptr += GROUP_SIZE / 8;
|
||||||
gate_128_ptr += stride_i_t_128;
|
++t_load;
|
||||||
} else {
|
} else if (token_id + 1 < tokens_upper) {
|
||||||
cp_async4(s_up_stage_128_staged_ptr, up_128_ptr);
|
// We loaded everything from the current token, let's move on
|
||||||
up_128_ptr += stride_i_t_128;
|
// to the next one, and we checked that we have more tokens to load.
|
||||||
}
|
++token_id;
|
||||||
|
t_load = 0;
|
||||||
|
if (token_id >= next_expert_offset) {
|
||||||
|
// We need to find the next expert.
|
||||||
|
do {
|
||||||
|
// This is a loop because it's possible
|
||||||
|
// that some experts are assigned 0 tokens.
|
||||||
|
// NOTE: We are guaranteed that there's at least
|
||||||
|
// one more token left so we don't have to check for
|
||||||
|
// expert_id bounds.
|
||||||
|
++expert_id;
|
||||||
|
// This skips 1 memory read.
|
||||||
|
expert_offset = next_expert_offset;
|
||||||
|
next_expert_offset = s_expert_offsets[expert_id + 1];
|
||||||
|
} while (next_expert_offset == expert_offset);
|
||||||
|
|
||||||
|
base_i = expert_id * (stride_i_e / 8);
|
||||||
|
token_offset = 0;
|
||||||
|
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
|
||||||
|
} else {
|
||||||
|
// We remain within the same expert, so just
|
||||||
|
// move by H/4 __int128_t (2 * H/8).
|
||||||
|
base_i += stride_yq_t / 4;
|
||||||
|
token_offset++;
|
||||||
|
}
|
||||||
|
|
||||||
|
load_ptr = const_cast<__int128_t*>(input_128_ptr + base_i);
|
||||||
|
|
||||||
|
auto smem_load_ptr_staged = smem_load_ptr + load_stage_offset;
|
||||||
|
|
||||||
|
// It is very important that LOAD_STAGE_SIZE is constexpr to avoid
|
||||||
|
// unnecessary ALU ops.
|
||||||
|
load_stage_offset += LOAD_STAGE_SIZE;
|
||||||
|
load_stage_offset %= LOAD_STAGE_MOD;
|
||||||
|
|
||||||
|
cp_async4(smem_load_ptr_staged, load_ptr);
|
||||||
|
load_ptr += GROUP_SIZE / 8;
|
||||||
++t_load;
|
++t_load;
|
||||||
++load_stage_id;
|
|
||||||
}
|
}
|
||||||
// We fence even if there is nothing to load to simplify pipelining.
|
// We fence even if there is nothing to load to simplify pipelining.
|
||||||
cp_async_fence();
|
cp_async_fence();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// We need to warm-up the pipeline.
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < NUM_STAGES - 1; i++) {
|
for (int i = 0; i < NUM_STAGES - 1; i++) {
|
||||||
load_and_advance_y_pred();
|
load_and_advance_y_pred();
|
||||||
}
|
}
|
||||||
|
|
||||||
__int64_t* s_gate_ptr = reinterpret_cast<__int64_t*>(
|
__nv_fp8x4_e4m3* y_q_base_ptr =
|
||||||
s_buff_compute_32 + warp_id * (GROUP_SIZE / 2)) +
|
reinterpret_cast<__nv_fp8x4_e4m3*>(_y_q) + lane_id;
|
||||||
lane_id;
|
auto y_scale_base_ptr = _y_s + warp_position_scales * stride_ys_g;
|
||||||
__int64_t* s_up_ptr = s_gate_ptr + S_NUM_64 / 2;
|
|
||||||
|
|
||||||
static constexpr int32_t STAGE_SIZE = (GROUP_SIZE * NUM_WARPS) / 4u;
|
for (auto j = tokens_lower; j < tokens_upper; j++) {
|
||||||
  const Idx_t base_ys = expert_id * stride_ys_e;
  auto y_s_ptr = y_scale_base_ptr + base_ys + token_offset * stride_ys_t;
  __nv_fp8x4_e4m3* y_q_ptr =
      y_q_base_ptr + (expert_id * stride_yq_e + token_offset * stride_yq_t +
                      warp_position_yq * stride_yq_h) /
                         4;

  const int COMPUTE_LIMIT = H / (GROUP_SIZE * NUM_WARPS);
  int32_t compute_pipeline_offset_64 = 0;

  for (int i = 0; i < COMPUTE_LIMIT; i++) {
    cp_async_wait<NUM_STAGES - 2>();
    __syncthreads();
    load_and_advance_y_pred();

    __int64_t* gate64_ptr = s_gate64_ptr + compute_pipeline_offset_64;
    __int64_t* up64_ptr = s_up64_ptr + compute_pipeline_offset_64;

    // COMPUTE_STAGE_SIZE/MOD must also be constexpr!
    compute_pipeline_offset_64 += COMPUTE_STAGE_SIZE;
    compute_pipeline_offset_64 %= COMPUTE_STAGE_MOD;

    // We double-buffer pipelined loads so that the next load will
    // concurrently run with compute without overwrites.
    __int64_t gate64 = *gate64_ptr;
    __int64_t up64 = *up64_ptr;

    // Compute
    __nv_bfloat162 res[2];
    __nv_bfloat162* s_gate_comp = reinterpret_cast<__nv_bfloat162*>(&gate64);
    __nv_bfloat162* s_up_comp = reinterpret_cast<__nv_bfloat162*>(&up64);

#pragma unroll
    for (int32_t k = 0; k < 2; ++k) {
      // For silu, we make sure that div is emitted.
      __nv_bfloat162 gate = silu2_v2(__bfloat1622float2(s_gate_comp[k]));
      res[k] = __hmul2(gate, s_up_comp[k]);
    }

    auto _y_max2 = __hmax2(__habs2(res[0]), __habs2(res[1]));
    _y_max2.x = __hmax(__hmax(_y_max2.x, _y_max2.y), EPS);

    __nv_bfloat16 y_s = __hmul(warp_max(_y_max2.x), fp8_inv);

    if constexpr (USE_UE8M0) {
      y_s = hexp2(hceil(hlog2(y_s)));
    }

    __nv_bfloat16 inv_y = __hdiv(one_bf16, y_s);
    auto y_s2 = make_bfloat162(inv_y, inv_y);

#pragma unroll
    for (int32_t k = 0; k < 2; ++k) {
      res[k] = clip(__hmul2(res[k], y_s2), __bfloat162bfloat162(fp8_min),
                    __bfloat162bfloat162(fp8_max));
    }

    *y_q_ptr = __nv_fp8x4_e4m3(res[0], res[1]);
    y_q_ptr += WARP_SIZE * stride_yq_h;

    if (!lane_id) {
      *y_s_ptr = y_s;
      y_s_ptr += stride_ys_g;
    }
  }
}
#endif

}  // namespace vllm
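The CUDA loop above packs a lot of intrinsics; as a sanity check, the arithmetic it performs per 128-wide group can be written in a few lines of PyTorch. This is an illustrative reference only (names, the 448.0 e4m3 maximum, and the 2-D tensor layout are assumptions, not code from this diff):

import torch

FP8_MAX = 448.0  # assumed float8_e4m3fn max, mirroring fp8_max / fp8_inv
EPS = 1e-10      # assumed tiny epsilon, mirroring EPS in the kernel


def silu_mul_quant_reference(gate, up, group_size=128, use_ue8m0=False):
    # silu(gate) * up, computed per token over the hidden dimension H
    y = torch.nn.functional.silu(gate.float()) * up.float()   # (T, H)
    groups = y.view(y.shape[0], -1, group_size)                # (T, H//g, g)
    amax = groups.abs().amax(dim=-1).clamp_min(EPS)            # per-group abs-max
    y_s = amax / FP8_MAX                                       # per-group scale
    if use_ue8m0:
        y_s = torch.exp2(torch.ceil(torch.log2(y_s)))          # power-of-two scale
    q = (groups / y_s.unsqueeze(-1)).clamp(-FP8_MAX, FP8_MAX)
    y_q = q.reshape_as(y).to(torch.float8_e4m3fn)
    return y_q, y_s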
@@ -475,14 +568,14 @@ void silu_and_mul_quant(torch::Tensor& out,  // [..., d]
   LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
 }
 
-void silu_mul_fp8_quant_deep_gemm_cuda(
-    const at::Tensor& input,   // (E, T, 2*H)
-    const at::Tensor& counts,  // (E)
+void persistent_masked_m_silu_mul_quant(
+    const at::Tensor& input,              // (E, T, 2*H)
+    const at::Tensor& tokens_per_expert,  // (E)
     at::Tensor& y_q,           // (E, T, H) [OUT]
     at::Tensor& y_s,           // (E, T, H//group_size) [OUT]
-    int64_t group_size, bool use_ue8m0, int64_t num_parallel_tokens) {
+    bool use_ue8m0) {
 #ifndef USE_ROCM
-  // This kernel relies heavily on cp.async and fp8 support.
   // This kernel currently only supports H % 128 == 0 and assumes a
   // fixed GROUP_SIZE of 128.
   TORCH_CHECK(input.dtype() == torch::kBFloat16);
@@ -491,10 +584,6 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
   TORCH_CHECK(y_s.dtype() == torch::kFloat32);
   TORCH_CHECK(input.size(-1) % 256 == 0);
 
-  // Check that num_parallel_tokens is of power of 2 and between 1 and 64.
-  TORCH_CHECK(1 <= num_parallel_tokens && num_parallel_tokens <= 64);
-  TORCH_CHECK(!(num_parallel_tokens & (num_parallel_tokens - 1)));
-
   using Idx_t = int64_t;
 
   Idx_t E = input.size(0);
@@ -510,81 +599,54 @@ void silu_mul_fp8_quant_deep_gemm_cuda(
   Idx_t stride_ys_t = y_s.stride(1);
   Idx_t stride_ys_g = y_s.stride(2);
 
-  Idx_t stride_counts_e = counts.stride(0);
+  Idx_t stride_counts_e = tokens_per_expert.stride(0);
 
   static constexpr int GROUP_SIZE = 128;
 
-#define KERNEL_FN                                                          \
-  if (use_ue8m0) {                                                        \
-    vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t,    \
-                                              NUM_PARALLEL_TOKENS, true>  \
-        <<<grid, block, 0, stream>>>(                                     \
-            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),           \
-            (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(),                \
-            reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G,     \
-            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
-            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g,           \
-            stride_counts_e);                                             \
-  } else {                                                                \
-    vllm::silu_mul_fp8_quant_deep_gemm_kernel<fp8_t, NUM_WARPS, Idx_t,    \
-                                              NUM_PARALLEL_TOKENS, false> \
-        <<<grid, block, 0, stream>>>(                                     \
-            reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),           \
-            (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(),                \
-            reinterpret_cast<int32_t*>(counts.data_ptr<int>()), H, G,     \
-            stride_i_e, stride_i_t, stride_i_h, stride_yq_e, stride_yq_t, \
-            stride_yq_h, stride_ys_e, stride_ys_t, stride_ys_g,           \
-            stride_counts_e);                                             \
-  }
-
-#define KERNEL_CALL_H                                       \
-  if (H % (4 * GROUP_SIZE) == 0) {                          \
-    static constexpr int NUM_WARPS = 4;                     \
-    populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
-    KERNEL_FN                                               \
-  } else {                                                  \
-    static constexpr int NUM_WARPS = 1;                     \
-    populate_launch_params(NUM_WARPS, NUM_PARALLEL_TOKENS); \
-    KERNEL_FN                                               \
-  }
-
-#define KERNEL_CALL_TOP_LEVEL                      \
-  if (num_parallel_tokens == 1) {                  \
-    static constexpr int NUM_PARALLEL_TOKENS = 1;  \
-    KERNEL_CALL_H                                  \
-  } else if (num_parallel_tokens == 2) {           \
-    static constexpr int NUM_PARALLEL_TOKENS = 2;  \
-    KERNEL_CALL_H                                  \
-  } else if (num_parallel_tokens == 4) {           \
-    static constexpr int NUM_PARALLEL_TOKENS = 4;  \
-    KERNEL_CALL_H                                  \
-  } else if (num_parallel_tokens == 8) {           \
-    static constexpr int NUM_PARALLEL_TOKENS = 8;  \
-    KERNEL_CALL_H                                  \
-  } else if (num_parallel_tokens == 16) {          \
-    static constexpr int NUM_PARALLEL_TOKENS = 16; \
-    KERNEL_CALL_H                                  \
-  } else if (num_parallel_tokens == 32) {          \
-    static constexpr int NUM_PARALLEL_TOKENS = 32; \
-    KERNEL_CALL_H                                  \
-  } else if (num_parallel_tokens == 64) {          \
-    static constexpr int NUM_PARALLEL_TOKENS = 64; \
-    KERNEL_CALL_H                                  \
-  }
-
-  Idx_t G;
-  dim3 block, grid;
-  auto populate_launch_params = [&](int num_warps, int _num_parallel_tokens) {
-    G = H / Idx_t(group_size * num_warps);
-    grid = dim3(E * G, _num_parallel_tokens);
-    block = dim3(num_warps * WARP_SIZE);
-  };
-
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  VLLM_DISPATCH_FP8_TYPES(y_q.scalar_type(),
-                          "silu_mul_fp8_quant_deep_gemm_kernel",
-                          [&] { KERNEL_CALL_TOP_LEVEL });
+
+#define KERNEL(BLOCK_COUNT, USE_UE8M0, THREAD_COUNT, STAGES)                  \
+  static constexpr int NUM_WARPS = THREAD_COUNT / WARP_SIZE;                  \
+  int sms = SILU_V2_BLOCK_COUNT;                                              \
+  static constexpr int max_shared_mem_bytes =                                 \
+      GROUP_SIZE * 2 * STAGES * NUM_WARPS * 2;                                \
+  dim3 grid(sms), block(THREAD_COUNT);                                        \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));           \
+  VLLM_DISPATCH_FP8_TYPES(                                                    \
+      y_q.scalar_type(), "silu_mul_fp8_quant_deep_gemm_kernel", [&] {         \
+        vllm::silu_mul_fp8_quant_deep_gemm_kernel<                            \
+            BLOCK_COUNT, max_shared_mem_bytes, fp8_t, THREAD_COUNT, Idx_t,    \
+            USE_UE8M0, GROUP_SIZE, STAGES>                                    \
+            <<<grid, block, max_shared_mem_bytes + (E + 1) * 16, stream>>>(   \
+                reinterpret_cast<__nv_bfloat16*>(input.data_ptr()),           \
+                (fp8_t*)y_q.data_ptr(), y_s.data_ptr<float>(),                \
+                reinterpret_cast<int32_t*>(tokens_per_expert.data_ptr()), E,  \
+                T, H, stride_i_e, stride_i_t, stride_i_h, stride_yq_e,        \
+                stride_yq_t, stride_yq_h, stride_ys_e, stride_ys_t,           \
+                stride_ys_g, stride_counts_e);                                \
+      });
+
+  static constexpr int SILU_V2_BLOCK_COUNT = 132 * 32;
+
+  if (!use_ue8m0) {
+    if (H >= 4096) {
+      static constexpr int NUM_STAGES = 4;
+      static constexpr int THREAD_COUNT = 256;
+      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, NUM_STAGES);
+    } else {
+      static constexpr int THREAD_COUNT = 32;
+      KERNEL(SILU_V2_BLOCK_COUNT, false, THREAD_COUNT, 2);
+    }
+  } else {
+    if (H >= 4096) {
+      static constexpr int NUM_STAGES = 4;
+      static constexpr int THREAD_COUNT = 256;
+      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, NUM_STAGES);
+    } else {
+      static constexpr int THREAD_COUNT = 32;
+      KERNEL(SILU_V2_BLOCK_COUNT, true, THREAD_COUNT, 2);
+    }
+  }
+
 #endif
 }
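The dispatch at the bottom of the new host function boils down to a small heuristic: wide hidden sizes get a 256-thread block with a 4-stage cp.async pipeline, anything smaller falls back to a single warp with 2 stages. A Python mirror of that decision, for illustration only (the constants are copied from the code above):

WARP_SIZE = 32
GROUP_SIZE = 128


def pick_launch_config(H: int) -> dict:
    if H >= 4096:
        thread_count, num_stages = 256, 4
    else:
        thread_count, num_stages = 32, 2
    num_warps = thread_count // WARP_SIZE
    # GROUP_SIZE * 2 (gate and up) * stages * warps * 2 bytes of bf16,
    # as in the max_shared_mem_bytes expression of the KERNEL macro.
    smem_bytes = GROUP_SIZE * 2 * num_stages * num_warps * 2
    return {"threads": thread_count, "stages": num_stages, "smem_bytes": smem_bytes}


assert pick_launch_config(7168) == {"threads": 256, "stages": 4, "smem_bytes": 16384}
assert pick_launch_config(2048) == {"threads": 32, "stages": 2, "smem_bytes": 1024}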
@@ -14,6 +14,8 @@
  * limitations under the License.
  */
 
+#include "core/registration.h"
+
 #include <torch/all.h>
 #include <cutlass/arch/arch.h>
 
@@ -418,3 +420,7 @@ void cutlass_fp4_group_mm(
       "12.8 or above.");
 #endif
 }
+
+TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
+  m.impl("cutlass_fp4_group_mm", &cutlass_fp4_group_mm);
+}
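Once the TORCH_LIBRARY_IMPL registration above is compiled into the extension, the op becomes reachable through the torch.ops registry from Python. A minimal lookup sketch; the "_C" namespace is an assumption (it depends on TORCH_EXTENSION_NAME) and the call signature is not shown in this diff, so only the lookup is illustrated:

import torch

# Assumes the compiled vLLM extension that performs the registration above
# has already been loaded (e.g. by importing vllm).
fp4_group_mm = torch.ops._C.cutlass_fp4_group_mm
print(fp4_group_mm)  # an op overload packet once registration has run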
@@ -6,7 +6,7 @@
 
 #include "quantization/vectorization.cuh"
 // TODO(luka/varun):refactor common.cuh to use this file instead
-#include "quantization/fp8/common.cuh"
+#include "quantization/w8a8/fp8/common.cuh"
 
 namespace vllm {
 
@@ -17,28 +17,32 @@ FILE_HEAD = """
 namespace MARLIN_NAMESPACE_NAME {
 """.strip()
 
-TEMPLATE = ("template __global__ void Marlin<"
-            "{{scalar_t}}, "
-            "{{w_type_id}}, "
-            "{{s_type_id}}, "
-            "{{threads}}, "
-            "{{thread_m_blocks}}, "
-            "{{thread_n_blocks}}, "
-            "{{thread_k_blocks}}, "
-            "{{'true' if m_block_size_8 else 'false'}}, "
-            "{{stages}}, "
-            "{{group_blocks}}, "
-            "{{'true' if is_zp_float else 'false'}}>"
-            "( MARLIN_KERNEL_PARAMS );")
+TEMPLATE = (
+    "template __global__ void Marlin<"
+    "{{scalar_t}}, "
+    "{{w_type_id}}, "
+    "{{s_type_id}}, "
+    "{{threads}}, "
+    "{{thread_m_blocks}}, "
+    "{{thread_n_blocks}}, "
+    "{{thread_k_blocks}}, "
+    "{{'true' if m_block_size_8 else 'false'}}, "
+    "{{stages}}, "
+    "{{group_blocks}}, "
+    "{{'true' if is_zp_float else 'false'}}>"
+    "( MARLIN_KERNEL_PARAMS );"
+)
 
 # int8 with zero point case (vllm::kU8) is also supported,
 # we don't add it to reduce wheel size.
 SCALAR_TYPES = [
-    "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn",
-    "vllm::kFE2M1f"
+    "vllm::kU4",
+    "vllm::kU4B8",
+    "vllm::kU8B128",
+    "vllm::kFE4M3fn",
+    "vllm::kFE2M1f",
 ]
-THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128),
-                  (128, 64, 128)]
+THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
 
 THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
 # group_blocks:
@@ -59,11 +63,12 @@ def generate_new_kernels():
     all_template_str_list = []
 
     for group_blocks, m_blocks, thread_configs in itertools.product(
-            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
+        GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
+    ):
         # act order case only support gptq-int4 and gptq-int8
         if group_blocks == 0 and scalar_type not in [
-                "vllm::kU4B8", "vllm::kU8B128"
+            "vllm::kU4B8",
+            "vllm::kU8B128",
         ]:
             continue
         if thread_configs[2] == 256:
@@ -93,8 +98,7 @@ def generate_new_kernels():
         c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
 
         is_zp_float_list = [False]
-        if dtype == "fp16" and scalar_type == "vllm::kU4" and \
-                group_blocks == 4:
+        if dtype == "fp16" and scalar_type == "vllm::kU4" and group_blocks == 4:
             # HQQ (is_zp_float = true) only supports
             # 4bit quantization and fp16
             is_zp_float_list.append(True)
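The TEMPLATE string above is ordinary jinja2 syntax, so each kernel instantiation line is produced by a plain render call. A trimmed-down sketch (the field list is shortened and the values are made up; the real generator iterates over SCALAR_TYPES, THREAD_CONFIGS, THREAD_M_BLOCKS and GROUP_BLOCKS):

import jinja2

MINI_TEMPLATE = (
    "template __global__ void Marlin<"
    "{{scalar_t}}, {{w_type_id}}, {{threads}}, {{thread_m_blocks}}, "
    "{{'true' if is_zp_float else 'false'}}>( MARLIN_KERNEL_PARAMS );"
)

line = jinja2.Template(MINI_TEMPLATE).render(
    scalar_t="nv_bfloat16",
    w_type_id=4,
    threads=256,
    thread_m_blocks=1,
    is_zp_float=False,
)
print(line)
# template __global__ void Marlin<nv_bfloat16, 4, 256, 1, false>( MARLIN_KERNEL_PARAMS );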
@@ -12,20 +12,21 @@ from functools import reduce
 from typing import Optional, Union
 
 import jinja2
-# yapf conflicts with isort for this block
-# yapf: disable
-from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
-                                            EpilogueScheduleType,
-                                            MixedInputKernelScheduleType,
-                                            TileSchedulerTag,
-                                            TileSchedulerType, VLLMDataType,
-                                            VLLMDataTypeNames,
-                                            VLLMDataTypeSize, VLLMDataTypeTag,
-                                            VLLMDataTypeTorchDataTypeTag,
-                                            VLLMDataTypeVLLMScalarTypeTag,
-                                            VLLMKernelScheduleTag)
-# yapf: enable
+from vllm_cutlass_library_extension import (
+    DataType,
+    EpilogueScheduleTag,
+    EpilogueScheduleType,
+    MixedInputKernelScheduleType,
+    TileSchedulerTag,
+    TileSchedulerType,
+    VLLMDataType,
+    VLLMDataTypeNames,
+    VLLMDataTypeSize,
+    VLLMDataTypeTag,
+    VLLMDataTypeTorchDataTypeTag,
+    VLLMDataTypeVLLMScalarTypeTag,
+    VLLMKernelScheduleTag,
+)
 
 #
 # Generator templating
@@ -286,18 +287,23 @@ def generate_sch_sig(schedule_config: ScheduleConfig) -> str:
     tile_shape = (
         f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
     )
-    cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
-                     f"x{schedule_config.cluster_shape_mnk[1]}" +
-                     f"x{schedule_config.cluster_shape_mnk[2]}")
-    kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
-        .split("::")[-1]
-    epilogue_schedule = EpilogueScheduleTag[
-        schedule_config.epilogue_schedule].split("::")[-1]
-    tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
-        .split("::")[-1]
+    cluster_shape = (
+        f"{schedule_config.cluster_shape_mnk[0]}"
+        + f"x{schedule_config.cluster_shape_mnk[1]}"
+        + f"x{schedule_config.cluster_shape_mnk[2]}"
+    )
+    kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule].split(
+        "::"
+    )[-1]
+    epilogue_schedule = EpilogueScheduleTag[schedule_config.epilogue_schedule].split(
+        "::"
+    )[-1]
+    tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler].split("::")[-1]
 
-    return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
-            f"_{epilogue_schedule}_{tile_scheduler}")
+    return (
+        f"{tile_shape}_{cluster_shape}_{kernel_schedule}"
+        + f"_{epilogue_schedule}_{tile_scheduler}"
+    )
 
 
 # mostly unique shorter sch_sig
@@ -316,18 +322,24 @@ def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str:
 
 # unique type_name
 def generate_type_signature(kernel_types: TypeConfig):
-    return str("".join([
-        VLLMDataTypeNames[getattr(kernel_types, field.name)]
-        for field in fields(TypeConfig)
-    ]))
+    return str(
+        "".join(
+            [
+                VLLMDataTypeNames[getattr(kernel_types, field.name)]
+                for field in fields(TypeConfig)
+            ]
+        )
+    )
 
 
 def generate_type_option_name(kernel_types: TypeConfig):
-    return ", ".join([
-        f"{field.name.replace('b_', 'with_')+'_type'}=" +
-        VLLMDataTypeNames[getattr(kernel_types, field.name)]
-        for field in fields(TypeConfig)
-    ])
+    return ", ".join(
+        [
+            f"{field.name.replace('b_', 'with_') + '_type'}="
+            + VLLMDataTypeNames[getattr(kernel_types, field.name)]
+            for field in fields(TypeConfig)
+        ]
+    )
 
 
 def is_power_of_two(n):
@@ -335,7 +347,6 @@ def is_power_of_two(n):
 
 
 def to_cute_constant(value: list[int]):
-
     def _to_cute_constant(value: int):
         if is_power_of_two(value):
             return f"_{value}"
@@ -350,11 +361,11 @@ def to_cute_constant(value: list[int]):
 
 def unique_schedules(impl_configs: list[ImplConfig]):
     # Use dict over set for deterministic ordering
-    return list({
-        sch: None
-        for impl_config in impl_configs
-        for sch in impl_config.schedules
-    }.keys())
+    return list(
+        {
+            sch: None for impl_config in impl_configs for sch in impl_config.schedules
+        }.keys()
+    )
 
 
 def unsigned_type_with_bitwidth(num_bits):
@@ -380,7 +391,7 @@ template_globals = {
     "gen_type_sig": generate_type_signature,
     "unique_schedules": unique_schedules,
     "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth,
-    "gen_type_option_name": generate_type_option_name
+    "gen_type_option_name": generate_type_option_name,
 }
 
 
@@ -398,23 +409,28 @@ prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
 def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
     sources = []
 
-    sources.append((
-        "machete_mm_dispatch",
-        mm_dispatch_template.render(impl_configs=impl_configs),
-    ))
+    sources.append(
+        (
+            "machete_mm_dispatch",
+            mm_dispatch_template.render(impl_configs=impl_configs),
+        )
+    )
 
     prepack_types = []
     for impl_config in impl_configs:
-        convert_type = impl_config.types.a \
-            if impl_config.types.b_group_scale == DataType.void \
-            else impl_config.types.b_group_scale
+        convert_type = (
+            impl_config.types.a
+            if impl_config.types.b_group_scale == DataType.void
+            else impl_config.types.b_group_scale
+        )
         prepack_types.append(
             PrepackTypeConfig(
                 a=impl_config.types.a,
                 b_num_bits=VLLMDataTypeSize[impl_config.types.b],
                 convert=convert_type,
                 accumulator=impl_config.types.accumulator,
-            ))
+            )
+        )
 
     def prepacked_type_key(prepack_type: PrepackTypeConfig):
         # For now, we can just use the first accumulator type seen since
@@ -430,10 +446,14 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
             unique_prepack_types.append(prepack_type)
             prepack_types_seen.add(key)
 
-    sources.append((
-        "machete_prepack",
-        prepack_dispatch_template.render(types=unique_prepack_types, ),
-    ))
+    sources.append(
+        (
+            "machete_prepack",
+            prepack_dispatch_template.render(
+                types=unique_prepack_types,
+            ),
+        )
+    )
 
     # Split up impls across files
     num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0)
@@ -466,10 +486,12 @@ def create_sources(impl_configs: list[ImplConfig], num_impl_files=8):
         curr_impl_in_file += len(files_impls[-1][-1].schedules)
 
     for part, file_impls in enumerate(files_impls):
-        sources.append((
-            f"machete_mm_impl_part{part+1}",
-            mm_impl_template.render(impl_configs=file_impls),
-        ))
+        sources.append(
+            (
+                f"machete_mm_impl_part{part + 1}",
+                mm_impl_template.render(impl_configs=file_impls),
+            )
+        )
 
     return sources
 
@@ -514,8 +536,7 @@ def generate():
     # For now we use the same heuristic for all types
     # Heuristic is currently tuned for H100s
     default_heuristic = [
-        (cond, ScheduleConfig(*tile_config,
-                              **sch_common_params))  # type: ignore
+        (cond, ScheduleConfig(*tile_config, **sch_common_params))  # type: ignore
        for cond, tile_config in default_tile_heuristic_config.items()
     ]
 
@@ -541,14 +562,18 @@ def generate():
             a_token_scale=DataType.void,
             out=a,
             accumulator=DataType.f32,
-        ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
-        for a in (DataType.f16, DataType.bf16))
+        )
+        for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
+        for a in (DataType.f16, DataType.bf16)
+    )
 
     impl_configs += [
         ImplConfig(x[0], x[1], x[2])
-        for x in zip(GPTQ_kernel_type_configs,
-                     itertools.repeat(get_unique_schedules(default_heuristic)),
-                     itertools.repeat(default_heuristic))
+        for x in zip(
+            GPTQ_kernel_type_configs,
+            itertools.repeat(get_unique_schedules(default_heuristic)),
+            itertools.repeat(default_heuristic),
+        )
     ]
 
     AWQ_kernel_type_configs = list(
@@ -561,14 +586,18 @@ def generate():
             a_token_scale=DataType.void,
             out=a,
             accumulator=DataType.f32,
-        ) for b in (DataType.u4, DataType.u8)
-        for a in (DataType.f16, DataType.bf16))
+        )
+        for b in (DataType.u4, DataType.u8)
+        for a in (DataType.f16, DataType.bf16)
+    )
 
     impl_configs += [
         ImplConfig(x[0], x[1], x[2])
-        for x in zip(AWQ_kernel_type_configs,
-                     itertools.repeat(get_unique_schedules(default_heuristic)),
-                     itertools.repeat(default_heuristic))
+        for x in zip(
+            AWQ_kernel_type_configs,
+            itertools.repeat(get_unique_schedules(default_heuristic)),
+            itertools.repeat(default_heuristic),
+        )
     ]
 
     # TODO: Support W4A8 when ready
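generate_sch_sig() and its relatives above are mostly string plumbing: a tile shape, a cluster shape, and the last component of each CUTLASS schedule tag, joined with underscores. A self-contained sketch of that naming scheme (the dataclass below is a stand-in, not the real ScheduleConfig):

from dataclasses import dataclass


@dataclass
class MiniScheduleConfig:
    tile_shape_mn: tuple
    cluster_shape_mnk: tuple
    kernel_schedule: str
    epilogue_schedule: str
    tile_scheduler: str


def mini_sch_sig(c: MiniScheduleConfig) -> str:
    tile_shape = f"{c.tile_shape_mn[0]}x{c.tile_shape_mn[1]}"
    cluster_shape = "x".join(str(v) for v in c.cluster_shape_mnk)
    tags = [
        s.split("::")[-1]
        for s in (c.kernel_schedule, c.epilogue_schedule, c.tile_scheduler)
    ]
    return "_".join([tile_shape, cluster_shape, *tags])


print(
    mini_sch_sig(
        MiniScheduleConfig(
            (128, 256),
            (2, 1, 1),
            "cutlass::gemm::KernelTmaWarpSpecializedCooperative",
            "cutlass::epilogue::TmaWarpSpecializedCooperative",
            "cutlass::gemm::PersistentScheduler",
        )
    )
)
# 128x256_2x1x1_KernelTmaWarpSpecializedCooperative_TmaWarpSpecializedCooperative_PersistentScheduler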
@@ -231,7 +231,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
   } else {
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
         OutType, 1, TILE_N, TILE_K, Shape<_64, Int<TILE_N>, Int<TILE_K>>,
-        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
         cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
         out, a, b, a_scales, b_scales);
   }
@@ -245,7 +245,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
   } else {
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
         OutType, 1, TILE_N, TILE_K, Shape<_128, Int<TILE_N>, Int<TILE_K>>,
-        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
         cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
         out, a, b, a_scales, b_scales);
   }
@@ -259,7 +259,7 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
   } else {
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
         OutType, 1, TILE_N, TILE_K, Shape<_256, Int<TILE_N>, Int<TILE_K>>,
-        Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+        Shape<_2, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm,
         cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
         out, a, b, a_scales, b_scales);
   }
@@ -271,10 +271,10 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
     // TMA epilogue isn't compatible with Swap A/B
     cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
         OutType, TILE_M, 1, TILE_K, Shape<Int<TILE_M>, Int<TILE_N>, Int<TILE_K>>,
-        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        Shape<_1, _1, _1>, cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm,
         cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
         out, a, b, a_scales, b_scales);
   }
 }
 
 }  // namespace vllm
@@ -25,7 +25,10 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
       if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
         int8_func(c, a, b, a_scales, b_scales, bias);
       } else {
-        TORCH_CHECK(false, "Int8 not supported for this architecture");
+        int32_t version_num = get_sm_version_num();
+        TORCH_CHECK(
+            false, "Int8 not supported on SM", version_num,
+            ". Use FP8 quantization instead, or run on older arch (SM < 100).");
       }
     }
   } else {
@@ -133,4 +133,4 @@ void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
   }
 }
 
 }  // namespace vllm
@@ -67,8 +67,9 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                              std::optional<torch::Tensor> const& bias);
 #endif
 
 #if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 ||   \
-    defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100
+    defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 || \
+    defined(ENABLE_SCALED_MM_SM120) && ENABLE_SCALED_MM_SM120
 void get_cutlass_moe_mm_data_caller(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
@@ -253,7 +254,7 @@ void cutlass_moe_mm(
     bool per_act_token, bool per_out_ch) {
   int32_t version_num = get_sm_version_num();
 #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
-  if (version_num >= 100) {
+  if (version_num >= 100 && version_num < 110) {
     cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                          expert_offsets, problem_sizes, a_strides, b_strides,
                          c_strides, per_act_token, per_out_ch);
@@ -261,7 +262,7 @@ void cutlass_moe_mm(
   }
 #endif
 #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
-  if (version_num >= 90) {
+  if (version_num >= 90 && version_num < 100) {
     cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
                         expert_offsets, problem_sizes, a_strides, b_strides,
                         c_strides, per_act_token, per_out_ch);
Some files were not shown because too many files have changed in this diff.