Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.9.2rc1...nixl-upstr (273 commits)
SHA1 | Author | Date | |
---|---|---|---|
8a8b40d417 | |||
c3f7afa6a8 | |||
6cd8dec23f | |||
723263fa23 | |||
f29fd8a7f8 | |||
ed10f3cea1 | |||
b637e9dcb8 | |||
1e36c8687e | |||
5bac61362b | |||
313ae8c16a | |||
c847e34b39 | |||
e7e3e6d263 | |||
4ffd963fa0 | |||
56fe4bedd6 | |||
d91278181d | |||
20149d84d9 | |||
3534c39a20 | |||
c586b55667 | |||
33d560001e | |||
f148c44c6a | |||
235bfd5dfe | |||
68d28e37b0 | |||
37a7d5d74a | |||
d4d309409f | |||
85bd6599e4 | |||
91b3d190ae | |||
fc017915f5 | |||
9ad0a4588b | |||
016b8d1b7f | |||
80305c1b24 | |||
37e2ecace2 | |||
054c8657e3 | |||
d4170fad39 | |||
946aadb4a0 | |||
bcdfb2a330 | |||
ba8c300018 | |||
8cdc371217 | |||
61e20828da | |||
55e1c66da5 | |||
86f3ac21ce | |||
149f2435a5 | |||
c0569dbc82 | |||
8bb43b9c9e | |||
559756214b | |||
6d0cf239c6 | |||
3fc964433a | |||
0caf61c08a | |||
667624659b | |||
38efa28278 | |||
e8cc53af5e | |||
a4851cfe68 | |||
9887e8ec50 | |||
f326ab9c88 | |||
dcf2a5e208 | |||
1e9438e0b0 | |||
697ef765ee | |||
a99b9f7dee | |||
c488b928a7 | |||
2c7fa47161 | |||
88fc8a97e3 | |||
66f6fbd393 | |||
8632e831ba | |||
4bbfc36b16 | |||
80d38b8ac8 | |||
211b6a6113 | |||
247102f07f | |||
bd4c1e6fdb | |||
99b4f080d8 | |||
020f58abcd | |||
c1acd6d7d4 | |||
3b3b778d4a | |||
42d440c22b | |||
f45a332886 | |||
6e2c176e1f | |||
a86754a12b | |||
c2a2f19aba | |||
2c11a738b3 | |||
b639327ad9 | |||
4afe687a82 | |||
5de8d9f111 | |||
c1c8ca57ff | |||
a3a5a47e48 | |||
fb25e95688 | |||
0d4891cd03 | |||
f56d2996ca | |||
147afb448b | |||
3c7d942da8 | |||
890323dc1b | |||
01cae37713 | |||
11c0198615 | |||
b1235c3e10 | |||
44d02f54db | |||
a8593237c0 | |||
fc0f41d10a | |||
7b828e30d5 | |||
5f0af36af5 | |||
0d21b2664c | |||
9907fc4494 | |||
d47661f0cd | |||
53fa457391 | |||
6fb162447b | |||
66177189c5 | |||
b4f0b5f9aa | |||
cbd14ed561 | |||
7bd4c37ae7 | |||
8020e98c9f | |||
762be26a8e | |||
6a9e6b2abf | |||
5d09152ff1 | |||
31d5c1797f | |||
35514b682a | |||
e2de455c34 | |||
5b032352cc | |||
922f316441 | |||
5923ab9524 | |||
0cf893cae1 | |||
cf75cd2098 | |||
b854321ffe | |||
5b6fe23d05 | |||
f0c98cae27 | |||
574ad60db9 | |||
fdadb6f43a | |||
41060c6e08 | |||
3de2ed767f | |||
299252ea82 | |||
d6902ce79f | |||
5e53c89a74 | |||
c66e38ea4c | |||
251595368f | |||
4bed167768 | |||
b140416abf | |||
5b8366b61a | |||
c7753a9809 | |||
4b9a9435bb | |||
3482fd7e4e | |||
77f77a951e | |||
1a4f35e2ea | |||
be1e128dfb | |||
65393ee064 | |||
dc221ad72d | |||
7571a4a7e5 | |||
f67d986dd1 | |||
cc876d0f29 | |||
fdfd409f8f | |||
ffbcc9e757 | |||
59389c927b | |||
8f2720def9 | |||
ad6c2e1a0b | |||
49e8c7ea25 | |||
805d62ca88 | |||
b7d9e9416f | |||
7c12a765aa | |||
cd587c93ef | |||
332d4cb17b | |||
bf03ff3575 | |||
47043eb678 | |||
31b96d1c64 | |||
e59ba9e142 | |||
403b481573 | |||
138709f8d1 | |||
0bbac1c1b4 | |||
a3e4e85ece | |||
eb58f5953d | |||
4ac9c33f78 | |||
efe73d0575 | |||
853487bc1b | |||
9ff2af6d2b | |||
70ca5484f5 | |||
5358cce5ff | |||
2155e95ef1 | |||
f95570a52d | |||
b6e7e3d58f | |||
e760fcef22 | |||
6bbf1795b7 | |||
9e0ef888f0 | |||
97abeb1daa | |||
34dad19e7b | |||
6db31e7a27 | |||
977180c912 | |||
c40784c794 | |||
baed180aa0 | |||
0b407479ef | |||
5eaf570050 | |||
d8ee5a2ca4 | |||
b9fca83256 | |||
32dffc2772 | |||
c438183e99 | |||
baba0389f7 | |||
c6c22f16d3 | |||
dd382e0fe3 | |||
849590a2a7 | |||
a4c23314c0 | |||
b942c094e3 | |||
b4bab81660 | |||
b91cb3fa5c | |||
71d1d75b7a | |||
72d14d0eed | |||
e34d130c16 | |||
7721ef1786 | |||
8369b7c2a9 | |||
3eb4ad53f3 | |||
90a2769f20 | |||
e60d422f19 | |||
0d914c81a2 | |||
6e428cdd7a | |||
93b9d9f499 | |||
af107d5a0e | |||
31c5d0a1b7 | |||
afb7cff1b9 | |||
d2e841a10a | |||
14601f5fba | |||
042d131f39 | |||
8e807cdfa4 | |||
e601efcb10 | |||
22dd9c2730 | |||
a6d795d593 | |||
a37d75bbec | |||
edd270bc78 | |||
110df74332 | |||
1ad69e8375 | |||
b8a498c9b2 | |||
923147b5e8 | |||
45877ef740 | |||
6e4bef1bea | |||
4ff79a136e | |||
448acad31e | |||
eb0b2d2f08 | |||
3112271f6e | |||
1fd471e957 | |||
2c5ebec064 | |||
2e610deb72 | |||
6e2c19ce22 | |||
47db8c2c15 | |||
462b269280 | |||
c18b3b8e8b | |||
9528e3a05e | |||
9fb52e523a | |||
e202dd2736 | |||
43813e6361 | |||
cede942b87 | |||
fe1e924811 | |||
4548c03c50 | |||
40b86aa05e | |||
432870829d | |||
f73d02aadc | |||
c5ebe040ac | |||
8d763cb891 | |||
cf4cd53982 | |||
32c9be2200 | |||
8aeaa910a2 | |||
906e05d840 | |||
ef9a2990ae | |||
7e90870491 | |||
d3f05c9248 | |||
c108781c85 | |||
3d184b95b8 | |||
2f35a022e6 | |||
ffe00ef77a | |||
5561681d04 | |||
fbd62d8750 | |||
2e26f9156a | |||
9e5452ee34 | |||
0e3fe896e2 | |||
1caca5a589 | |||
783921d889 | |||
4a98edff1f | |||
a7bab0c9e5 | |||
25950dca9b | |||
a4113b035c | |||
7e1665b089 | |||
8d1096e7db | |||
8d775dd30a | |||
78fe77534b |
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done

lm_eval --model vllm \
  --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
  --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
  --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
  --batch_size "$BATCH_SIZE"
@@ -18,12 +18,14 @@ RTOL = 0.08


def launch_lm_eval(eval_config, tp_size):
    trust_remote_code = eval_config.get("trust_remote_code", False)
    max_model_len = eval_config.get("max_model_len", 4096)
    model_args = (
        f"pretrained={eval_config['model_name']},"
        f"tensor_parallel_size={tp_size},"
        f"enforce_eager=true,"
        f"add_bos_token=true,"
        f"trust_remote_code={trust_remote_code}"
        f"trust_remote_code={trust_remote_code},"
        f"max_model_len={max_model_len}"
    )
    results = lm_eval.simple_evaluate(
        model="vllm",
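For orientation, here is a minimal sketch of the updated helper written out as a complete function, assuming it receives an eval_config dict carrying model_name plus the optional trust_remote_code/max_model_len fields shown above; the tasks, num_fewshot, and batch_size defaults below are illustrative assumptions, not taken from the diff.

```python
import lm_eval


def launch_lm_eval(eval_config: dict, tp_size: int):
    # Optional knobs fall back to defaults when they are absent from the config.
    trust_remote_code = eval_config.get("trust_remote_code", False)
    max_model_len = eval_config.get("max_model_len", 4096)
    model_args = (
        f"pretrained={eval_config['model_name']},"
        f"tensor_parallel_size={tp_size},"
        f"enforce_eager=true,"
        f"add_bos_token=true,"
        f"trust_remote_code={trust_remote_code},"
        f"max_model_len={max_model_len}"
    )
    # lm-eval drives vLLM through its "vllm" backend; the task list and
    # few-shot count here are assumed defaults for the sketch.
    return lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=eval_config.get("tasks", ["gsm8k"]),
        num_fewshot=eval_config.get("num_fewshot", 5),
        batch_size="auto",
    )
```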
@@ -52,7 +52,7 @@ steps:
      queue: cpu_queue_postmerge
    commands:
      - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
      - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
      - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"

  - label: "Annotate release workflow"
@@ -107,10 +107,9 @@ fi

if [[ $commands == *" kernels/attention"* ]]; then
  commands="${commands} \
  --ignore=kernels/attention/stest_attention_selector.py \
  --ignore=kernels/attention/test_attention_selector.py \
  --ignore=kernels/attention/test_blocksparse_attention.py \
  --ignore=kernels/attention/test_encoder_decoder_attn.py \
  --ignore=kernels/attention/test_attention_selector.py \
  --ignore=kernels/attention/test_flash_attn.py \
  --ignore=kernels/attention/test_flashinfer.py \
  --ignore=kernels/attention/test_prefix_prefill.py \
@@ -48,10 +48,16 @@ function cpu_tests() {
  # Run basic model test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
    pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
    pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
    pytest -v -s tests/models/language/generation -m cpu_model
    VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model
    # Note: disable until supports V1
    # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
    # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model

    # Note: disable Bart until supports V1
    pytest -v -s tests/models/language/generation -m cpu_model \
      --ignore=tests/models/language/generation/test_bart.py
    VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \
      --ignore=tests/models/language/generation/test_bart.py

    pytest -v -s tests/models/language/pooling -m cpu_model
    pytest -v -s tests/models/multimodal/generation \
      --ignore=tests/models/multimodal/generation/test_mllama.py \
@@ -62,20 +68,14 @@ function cpu_tests() {
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
    pytest -s -v \
      tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
      tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
      tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"

  # Note: disable it until supports V1
  # Run AWQ test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
    VLLM_USE_V1=0 pytest -s -v \
      tests/quantization/test_ipex_quant.py"

  # Run chunked-prefill and prefix-cache test
  docker exec cpu-test-"$NUMA_NODE" bash -c "
    set -e
    pytest -s -v -k cpu_model \
      tests/basic_correctness/test_chunked_prefill.py"
  # docker exec cpu-test-"$NUMA_NODE" bash -c "
  #   set -e
  #   VLLM_USE_V1=0 pytest -s -v \
  #     tests/quantization/test_ipex_quant.py"

  # online serving
  docker exec cpu-test-"$NUMA_NODE" bash -c "
@@ -11,8 +11,8 @@ container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head
docker build -t ${image_name} -f docker/Dockerfile.xpu .

# Setup cleanup
remove_docker_container() {
  docker rm -f "${container_name}" || true;
remove_docker_container() {
  docker rm -f "${container_name}" || true;
  docker image rm -f "${image_name}" || true;
  docker system prune -f || true;
}
@@ -26,7 +26,9 @@ docker run \
  --name "${container_name}" \
  "${image_name}" \
  sh -c '
    VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
    VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
    VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
    cd tests
    pytest -v -s v1/core
  '
@@ -22,16 +22,6 @@ trap remove_docker_container EXIT
# Remove the container that might not be cleaned up in the previous run.
remove_docker_container

# Build docker image.
# TODO: build the image outside the script and share the image with other
# tpu test if building time is too long.
DOCKER_BUILDKIT=1 docker build \
  --build-arg max_jobs=16 \
  --build-arg USE_SCCACHE=1 \
  --build-arg GIT_REPO_CHECK=0 \
  --tag vllm/vllm-tpu-bm \
  --progress plain -f docker/Dockerfile.tpu .

LOG_ROOT=$(mktemp -d)
# If mktemp fails, set -e will cause the script to exit.
echo "Results will be stored in: $LOG_ROOT"
@@ -117,7 +117,7 @@ steps:
  commands:
  - pytest -v -s core

- label: Entrypoints Test # 40min
- label: Entrypoints Test (LLM) # 40min
  mirror_hardwares: [amdexperimental]
  working_dir: "/vllm-workspace/tests"
  fast_check: true
@@ -125,8 +125,6 @@ steps:
  source_file_dependencies:
  - vllm/
  - tests/entrypoints/llm
  - tests/entrypoints/openai
  - tests/entrypoints/test_chat_utils
  - tests/entrypoints/offline_mode
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
@@ -135,9 +133,21 @@ steps:
  - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
  - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
  - VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
  - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

- label: Entrypoints Test (API Server) # 40min
  mirror_hardwares: [amdexperimental]
  working_dir: "/vllm-workspace/tests"
  fast_check: true
  torch_nightly: true
  source_file_dependencies:
  - vllm/
  - tests/entrypoints/openai
  - tests/entrypoints/test_chat_utils
  commands:
  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
  - pytest -v -s entrypoints/test_chat_utils.py
  - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

- label: Distributed Tests (4 GPUs) # 10min
  mirror_hardwares: [amdexperimental]
@@ -282,7 +292,7 @@ steps:
  - python3 offline_inference/llm_engine_example.py
  - python3 offline_inference/audio_language.py --seed 0
  - python3 offline_inference/vision_language.py --seed 0
  - python3 offline_inference/vision_language_embedding.py --seed 0
  - python3 offline_inference/vision_language_pooling.py --seed 0
  - python3 offline_inference/vision_language_multi_image.py --seed 0
  - VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
  - python3 offline_inference/encoder_decoder.py
@@ -630,6 +640,18 @@ steps:
  # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
  # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*

- label: Transformers Nightly Models Test
  working_dir: "/vllm-workspace/"
  optional: true
  commands:
  - pip install --upgrade git+https://github.com/huggingface/transformers
  - pytest -v -s tests/models/test_initialization.py
  - pytest -v -s tests/models/multimodal/processing/
  - pytest -v -s tests/models/multimodal/test_mapping.py
  - python3 examples/offline_inference/basic/chat.py
  - python3 examples/offline_inference/audio_language.py --model-type whisper
  - python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl

##### 1 GPU test #####
##### multi gpus test #####
.gemini/config.yaml (new file): 6 lines
@@ -0,0 +1,6 @@
# https://developers.google.com/gemini-code-assist/docs/customize-gemini-behavior-github
have_fun: false # Just review the code
code_review:
  comment_severity_threshold: HIGH # Reduce quantity of comments
  pull_request_opened:
    summary: false # Don't summarize the PR in a separate comment
.github/CODEOWNERS (vendored): 3 changed lines
@@ -16,7 +16,8 @@
/vllm/lora @jeejeelee
/vllm/reasoning @aarnphm
/vllm/entrypoints @aarnphm
CMakeLists.txt @tlrmchlsmth
/vllm/compilation @zou3519 @youkaichao
CMakeLists.txt @tlrmchlsmth @LucasWilkinson

# Any change to the VllmConfig changes can have a large user-facing impact,
# so spam a lot of people
.github/mergify.yml (vendored): 2 changed lines
@@ -86,8 +86,6 @@ pull_request_rules:
    - and:
      - files~=^vllm/model_executor/models/
      - files=vllm/model_executor/models/registry.py
      - files=tests/models/registry.py
      - files=docs/models/supported_models.md
  actions:
    label:
      add:
.github/workflows/lint-and-deploy.yaml (vendored): 2 changed lines
@@ -68,7 +68,7 @@ jobs:
          export AWS_ACCESS_KEY_ID=minioadmin
          export AWS_SECRET_ACCESS_KEY=minioadmin
          sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
          helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
          helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"

      - name: curl test
        run: |
.gitignore (vendored): 1 changed line
@@ -146,6 +146,7 @@ venv.bak/

# mkdocs documentation
/site
docs/argparse
docs/examples

# mypy
@@ -166,11 +166,11 @@ repos:
      language: python
      types: [python]
      pass_filenames: true
      files: vllm/config.py|tests/test_config.py
      files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py
  # Keep `suggestion` last
  - id: suggestion
    name: Suggestion
    entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
    entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=<hook-id>."'
    language: system
    verbose: true
    pass_filenames: false
@@ -171,7 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
  list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()


#
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
@@ -232,7 +231,6 @@ endif()

set(VLLM_EXT_SRC
  "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
  "csrc/mamba/causal_conv1d/causal_conv1d.cu"
  "csrc/cache_kernels.cu"
  "csrc/attention/paged_attention_v1.cu"
  "csrc/attention/paged_attention_v2.cu"
@@ -259,7 +257,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

  # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
  set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
  set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")

  # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
  if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -393,7 +391,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
  # CUDA 12.0 or later
  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
    set(SRCS
      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
@@ -409,7 +407,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
    message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
  else()
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
      message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
                     "later if you intend on running FP8 quantized models on "
@@ -424,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
  # CUDA 12.8 or later
  cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
    set(SRCS
      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
@@ -438,7 +436,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
    message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
  else()
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
      message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
                     "later if you intend on running FP8 quantized models on "
@@ -453,7 +451,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
  # require CUDA 12.8 or later
  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
    set(SRCS
      "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
      "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
@@ -468,7 +466,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
    message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
  else()
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
      message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or "
                     "later if you intend on running FP8 quantized models on "
@@ -511,7 +509,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
  # require CUDA 12.2 or later (and only work on Hopper).
  cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
    set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
@@ -520,7 +518,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
    list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
    message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
  else()
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
      message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
                     "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
                     "if you intend on running FP8 sparse quantized models on Hopper.")
@@ -532,7 +530,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

  # FP4 Archs and flags
  cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
    set(SRCS
      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
      "csrc/quantization/fp4/nvfp4_experts_quant.cu"
@@ -553,9 +551,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

  # CUTLASS MLA Archs and flags
  cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
    set(SRCS
      "csrc/attention/mla/cutlass_mla_kernels.cu")
      "csrc/attention/mla/cutlass_mla_kernels.cu"
      "csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
      CUDA_ARCHS "${MLA_ARCHS}")
@@ -615,6 +614,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
                     "in CUDA target architectures.")
    endif()
  endif()

  cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
    set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
    set_gencode_flags_for_srcs(
      SRCS "${SRCS}"
      CUDA_ARCHS "${SCALED_MM_ARCHS}")
    list(APPEND VLLM_EXT_SRC "${SRCS}")
    list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
    message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
  else()
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
      message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
                     "not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
                     "if you intend on running FP8 quantized MoE models on Blackwell.")
    else()
      message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
                     "in CUDA target architectures")
    endif()
  endif()

  #
  # Machete kernels
@@ -622,7 +641,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
  # The machete kernels only work on hopper and require CUDA 12.0 or later.
  # Only build Machete kernels if we are building for something compatible with sm90a
  cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
    #
    # For the Machete kernels we automatically generate sources for various
    # preselected input type pairs and schedules.
@@ -674,7 +693,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

    message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
  else()
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
    if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
        AND MACHETE_ARCHS)
      message(STATUS "Not building Machete kernels as CUDA Compiler version is "
                     "not >= 12.0, we recommend upgrading to CUDA 12.0 or "
@@ -63,13 +63,11 @@ vLLM is fast with:
- Speculative decoding
- Chunked prefill

**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.

vLLM is flexible and easy to use with:

- Seamless integration with popular Hugging Face models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism and pipeline parallelism support for distributed inference
- Tensor, pipeline, data and expert parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
@@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
        input_low = int(real_input_len * (1 - range_ratio))
        input_high = int(real_input_len * (1 + range_ratio))
        output_low = int(output_len * (1 - range_ratio))
        # Ensure the lower bound for output length is at least 1 to prevent
        # sampling 0 tokens, which can cause request failures.
        output_low = max(output_low, 1)
        output_high = int(output_len * (1 + range_ratio))

        # Add logging for debugging
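A small self-contained sketch of the sampling this clamp protects, using numpy directly; the helper name and the seeded generator are assumptions for illustration, not code from the file.

```python
import numpy as np


def sample_lengths(real_input_len: int, output_len: int,
                   range_ratio: float, num_requests: int, seed: int = 0):
    # Bounds mirror the hunk above; the clamp keeps the sampler from ever
    # asking for 0 output tokens, which would fail the request.
    rng = np.random.default_rng(seed)
    input_low = int(real_input_len * (1 - range_ratio))
    input_high = int(real_input_len * (1 + range_ratio))
    output_low = max(int(output_len * (1 - range_ratio)), 1)
    output_high = int(output_len * (1 + range_ratio))
    input_lens = rng.integers(input_low, input_high + 1, size=num_requests)
    output_lens = rng.integers(output_low, output_high + 1, size=num_requests)
    return input_lens, output_lens
```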
@@ -701,6 +704,7 @@ class HuggingFaceDataset(BenchmarkDataset):
        self,
        dataset_path: str,
        dataset_split: str,
        no_stream: bool = False,
        dataset_subset: Optional[str] = None,
        **kwargs,
    ) -> None:
@@ -708,6 +712,7 @@ class HuggingFaceDataset(BenchmarkDataset):

        self.dataset_split = dataset_split
        self.dataset_subset = dataset_subset
        self.load_stream = not no_stream
        self.load_data()

    def load_data(self) -> None:
@@ -716,7 +721,7 @@ class HuggingFaceDataset(BenchmarkDataset):
            self.dataset_path,
            name=self.dataset_subset,
            split=self.dataset_split,
            streaming=True,
            streaming=self.load_stream,
        )
        self.data = self.data.shuffle(seed=self.random_seed)
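For reference, a compact sketch of what the new toggle does at the Hugging Face datasets level; the wrapper function and its arguments are hypothetical, while load_dataset's streaming flag and shuffle(seed=...) are the standard datasets API.

```python
from typing import Optional

from datasets import load_dataset


def load_hf_split(dataset_path: str, dataset_subset: Optional[str],
                  dataset_split: str, no_stream: bool, seed: int = 0):
    data = load_dataset(
        dataset_path,
        name=dataset_subset,
        split=dataset_split,
        # --no-stream downloads the whole split up front instead of streaming it.
        streaming=not no_stream,
    )
    # shuffle(seed=...) works both for a streamed IterableDataset and for a
    # fully materialized Dataset.
    return data.shuffle(seed=seed)
```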
@@ -825,6 +825,7 @@ def main(args: argparse.Namespace):
            dataset_subset=args.hf_subset,
            dataset_split=args.hf_split,
            random_seed=args.seed,
            no_stream=args.no_stream,
        ).sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
@@ -1033,6 +1034,11 @@ def create_argument_parser():
        help="Path to the sharegpt/sonnet dataset. "
        "Or the huggingface dataset ID if using HF dataset.",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Do not load the dataset in streaming mode.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
@@ -356,6 +356,7 @@ def get_requests(args, tokenizer):
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        common_kwargs["no_stream"] = args.no_stream
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs["dataset_subset"] = None
@@ -610,6 +611,11 @@ def create_argument_parser():
        help="Name of the dataset to benchmark on.",
        default="sharegpt",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Do not load the dataset in streaming mode.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
benchmarks/kernels/bench_nvfp4_gemm.py (new file): 141 lines
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"torch-bf16": dict(enabled=True),
|
||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
|
||||
# Compute global scale for weight
|
||||
b_amax = torch.abs(b).max().to(torch.float32)
|
||||
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
|
||||
return b_fp4, scale_b_fp4, b_global_scale
|
||||
|
||||
|
||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
|
||||
|
||||
# Compute global scale for activation
|
||||
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||
a_amax = torch.abs(a).max().to(torch.float32)
|
||||
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
|
||||
# Alpha for the GEMM operation
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
|
||||
if cfg["no_a_quant"]:
|
||||
# Pre-quantize activation
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
def run():
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
# Quantize activation on-the-fly
|
||||
def run():
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs NVFP4 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch-bf16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), quantiles=quantiles
|
||||
)
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
args = parser.parse_args()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
benchmarks/kernels/bench_per_token_quant_fp8.py (new file): 98 lines
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
def inner(*args):
|
||||
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||
return fn(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
compilation_config = CompilationConfig(custom_ops=["none"])
|
||||
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
|
||||
torch_per_token_quant_fp8 = torch.compile(
|
||||
QuantFP8(False, GroupShape.PER_TOKEN),
|
||||
fullgraph=True,
|
||||
dynamic=False, # recompile for different shapes
|
||||
)
|
||||
|
||||
# First dim is explicitly dynamic to simulate vLLM usage
|
||||
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input)
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between Triton and CUDA implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
|
||||
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
|
||||
|
||||
if torch.allclose(
|
||||
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda"],
|
||||
line_names=["Torch", "CUDA"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "cuda":
|
||||
fn = lambda: cuda_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
@@ -86,6 +86,9 @@ def benchmark_config(
        (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
    )
    w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
    if use_deep_gemm:
        # we use the default block shape for deepgemm
        block_quant_shape = [128, 128]
    if use_fp8_w8a8:
        if block_quant_shape:
            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
@@ -620,7 +623,7 @@ def main(args: argparse.Namespace):
            4096,
        ]
    else:
        batch_sizes = [args.batch_size]
        batch_sizes = args.batch_size

    use_deep_gemm = bool(args.use_deep_gemm)

@@ -728,7 +731,7 @@ if __name__ == "__main__":
    )
    parser.add_argument("--use-deep-gemm", action="store_true")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch-size", type=int, required=False)
    parser.add_argument("--batch-size", type=int, nargs="+", required=False)
    parser.add_argument("--tune", action="store_true")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--model-prefix", type=str, required=False)
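A quick illustration of why the nargs change lets main() assign args.batch_size directly as a list; this is a standalone argparse demo, not part of the benchmark.

```python
import argparse

# With nargs="+", argparse parses "--batch-size 1 16 64" into a list of ints,
# so the benchmark can use `batch_sizes = args.batch_size` directly instead of
# wrapping a single value in a list.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
args = parser.parse_args(["--batch-size", "1", "16", "64"])
print(args.batch_size)  # [1, 16, 64]
```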
benchmarks/kernels/benchmark_trtllm_attention.py (new file): 240 lines
@@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_decode(
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Currently only HEAD_GRP_SIZE == 8 is supported
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
|
||||
|
||||
# For decode, batch_size is num_decode_token
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
|
||||
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
|
||||
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
|
||||
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
# Benchmark TRT decode
|
||||
def trt_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
q,
|
||||
kv_cache,
|
||||
workspace_buffer,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
sm_scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
page_size,
|
||||
max_kv_len,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
for i in range(trials):
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
# TRT Decode
|
||||
trt_mean, trt_std = time_fn(trt_decode)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
|
||||
)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
|
||||
|
||||
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
|
||||
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
|
||||
def write_results_to_csv(results, filename=None):
|
||||
"""Write benchmark results to CSV file."""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
file_exists = os.path.exists(filename)
|
||||
|
||||
with open(filename, "a", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
|
||||
for result in results:
|
||||
writer.writerow(result)
|
||||
|
||||
print(f"Results written to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
print("Running benchmark for kv_cache_dtype: bfloat16")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
@@ -165,17 +165,32 @@ else()
endif()

#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
#
if (AVX512_FOUND AND NOT AVX512_DISABLED)
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if ( VLLM_BUILD_ACL STREQUAL "ON")
  set(USE_ACL ON)
else()
  set(USE_ACL OFF)
endif()

if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
  FetchContent_Declare(
    oneDNN
    GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
    GIT_TAG v3.7.1
    GIT_TAG v3.8.1
    GIT_PROGRESS TRUE
    GIT_SHALLOW TRUE
  )

  if(USE_ACL)
    find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
    if(NOT ARM_COMPUTE_LIBRARY)
      message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
    endif()
    set(ONEDNN_AARCH64_USE_ACL "ON")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
  endif()

  set(ONEDNN_LIBRARY_TYPE "STATIC")
  set(ONEDNN_BUILD_DOC "OFF")
  set(ONEDNN_BUILD_EXAMPLES "OFF")
@@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
    "csrc/cpu/quant.cpp"
    ${VLLM_EXT_SRC})
endif()
if (ASIMD_FOUND)
  set(VLLM_EXT_SRC
    "csrc/cpu/quant.cpp"
    ${VLLM_EXT_SRC})
endif()

message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
@@ -24,6 +24,7 @@

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "cuda_compat.h"

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
@@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
  #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif

#ifndef USE_ROCM
  #define WARP_SIZE 32
#else
  #define WARP_SIZE warpSize
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(

} // namespace vllm

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp (new file): 372 lines
@@ -0,0 +1,372 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
// printf("set_split_kv start");
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
// std::cout << H << " " << K << " " << D << " " << B << "\n";
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
// printf(" sm_count = %d\n", sm_count);
|
||||
int max_splits = ceil_div(K, 128);
|
||||
max_splits = min(16, max_splits);
|
||||
// printf(" max_splits = %d\n", max_splits);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
// printf(" sms_per_batch = %d\n", sms_per_batch);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
// printf(" args.split_kv = %d\n", args.split_kv);
|
||||
|
||||
}
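For reference, the wave-aware split selection above reduces to the host arithmetic below. This is a standalone illustrative sketch, not part of the diff; pick_split_kv and the example sizes are hypothetical, and the unused waves intermediate is omitted.

#include <algorithm>
#include <cstdio>

namespace {
int ceil_div_host(int a, int b) { return (a + b - 1) / b; }

// Mirrors set_split_kv: cap the number of KV splits at 16, then balance the
// split count against the SMs available per batch entry.
int pick_split_kv(int seq_len_k, int batch, int sm_count) {
  int max_splits = std::min(16, ceil_div_host(seq_len_k, 128));
  int sms_per_batch = std::max(1, sm_count / batch);
  int split_heur = std::min(max_splits, sms_per_batch);
  int k_waves = ceil_div_host(max_splits, split_heur);
  return ceil_div_host(max_splits, k_waves);  // split_wave_aware
}
}  // namespace

int main() {
  // Example only: an 8K context, batch 4, on a 148-SM device.
  std::printf("split_kv = %d\n", pick_split_kv(8192, 4, 148));
  return 0;
}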
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be constructed by calling Kernel::to_underlying_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {&params.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
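A minimal mock of the call sequence this device layer expects from callers: can_implement, get_workspace_size, initialize, then run. The types below are stand-ins written for illustration only; they are not the CUTLASS API, which the real usage later in this diff (runMla) exercises.

#include <cstddef>
#include <cstdio>
#include <vector>

// Mock handle illustrating the expected call order; not the CUTLASS types.
struct MockFmha {
  struct Arguments { int batches = 0; };
  struct Params { int batches = 0; };
  static bool can_implement(const Arguments& a) { return a.batches > 0; }
  static std::size_t get_workspace_size(const Arguments&) { return 1024; }
  bool initialize(const Arguments& a, void* ws) { params_ = {a.batches}; return ws != nullptr; }
  static bool run(const Params& p) { std::printf("launch for %d batches\n", p.batches); return true; }
  bool run() const { return run(params_); }  // re-launch without rebuilding params
  Params params_;
};

int main() {
  MockFmha::Arguments args{4};
  if (!MockFmha::can_implement(args)) return 1;
  std::vector<unsigned char> workspace(MockFmha::get_workspace_size(args));
  MockFmha fmha;
  if (!fmha.initialize(args, workspace.data())) return 1;
  return fmha.run() ? 0 : 1;
}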
|
@@ -0,0 +1,203 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
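The reduction kernel's merge step is a standard numerically stable log-sum-exp combine across splits. A host-side sketch of the same arithmetic for a single output element follows; the values and names are made up for illustration, and the kernel's extra guard for the all-minus-infinity case is only noted in a comment.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Per-split partial results for one (head, batch) position: the split's
  // log-sum-exp of the softmax denominator and its partial output value.
  std::vector<float> lse = {5.1f, 4.3f, 6.0f};
  std::vector<float> o_accum = {0.2f, 0.7f, 0.4f};

  // Stable combine: global_lse = log(sum_i exp(lse_i - lse_max)) + lse_max.
  // (The kernel additionally maps an all -inf maximum to 0 before this step.)
  float lse_max = *std::max_element(lse.begin(), lse.end());
  float sum = 0.f;
  for (float v : lse) sum += std::exp(v - lse_max);
  float global_lse = std::log(sum) + lse_max;

  // Each split's partial output is rescaled by exp(lse_i - global_lse) and summed.
  float merged = 0.f;
  for (std::size_t i = 0; i < lse.size(); ++i)
    merged += std::exp(lse[i] - global_lse) * o_accum[i];

  std::printf("global_lse = %f, merged output = %f\n", global_lse, merged);
  return 0;
}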
|
(File diff suppressed because it is too large.)

@@ -0,0 +1,165 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
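A host-side sketch (hypothetical sizes, plain division instead of FastDivmod) of how the persistent scheduler above maps a flat block index onto (m_block, batch, split_kv) coordinates and then strides by the grid size:

#include <cstdio>

int main() {
  // num_m_blocks corresponds to size<0>(cluster_shape) in the code above.
  const int num_m_blocks = 1, batches = 3, split_kv = 2;
  const int num_blocks = num_m_blocks * batches * split_kv;
  const int grid_size = 4;  // the launched grid: min(num_blocks, sm_count)

  for (int cta = 0; cta < grid_size; ++cta) {
    // Each persistent CTA starts at its blockIdx.x and advances by gridDim.x.
    for (int idx = cta; idx < num_blocks; idx += grid_size) {
      int rest = idx;
      int m_block = rest % num_m_blocks; rest /= num_m_blocks;
      int bidb    = rest % batches;      rest /= batches;
      int n_split = rest % split_kv;
      std::printf("cta %d -> (m=%d, b=%d, split=%d)\n", cta, m_block, bidb, n_split);
    }
  }
  return 0;
}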
|
csrc/attention/mla/sm100_cutlass_mla_kernel.cu (new file, 273 lines)
@@ -0,0 +1,273 @@
|
||||
/*
|
||||
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/kernel_hardware_info.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
// clang-format off
|
||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||
}
|
||||
#else
|
||||
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape,
|
||||
Element,
|
||||
ElementAcc,
|
||||
ElementOut,
|
||||
ElementAcc,
|
||||
TileScheduler,
|
||||
/*kIsCpAsync=*/!IsPaged128>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_nope = cute::make_tuple(
|
||||
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||
StrideQ stride_Q_pe = cute::make_tuple(
|
||||
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||
|
||||
StrideK stride_C = cute::make_tuple(
|
||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale,
|
||||
Q_nope_ptr,
|
||||
stride_Q_nope,
|
||||
Q_pe_ptr,
|
||||
stride_Q_pe,
|
||||
C_ptr,
|
||||
stride_C,
|
||||
C_ptr + D_latent,
|
||||
stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()),
|
||||
stride_PT,
|
||||
page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
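For context, the query strides passed above fall out of a packed [batch, heads, dim] tensor, where stride(1) is the head dimension and stride(0) is heads times dim. A small host sketch with example sizes (128 heads and a 576-wide latent+rope dim, chosen to match the tile shape, but purely illustrative):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t B = 2, H = 128, D = 576;  // example sizes, not from the diff
  // Row-major [B, H, D]: element (b, h, d) lives at b*H*D + h*D + d.
  const int64_t stride_h = D;      // like q_nope.stride(1)
  const int64_t stride_d = 1;
  const int64_t stride_b = H * D;  // like q_nope.stride(0)
  std::printf("offset(b=1,h=3,d=5) = %lld\n",
              (long long)(1 * stride_b + 3 * stride_h + 5 * stride_d));
  return 0;
}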
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
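DISPATCH_BOOL is the usual runtime-boolean-to-template-parameter dispatch idiom: each branch re-declares the name as a constexpr bool so the generic body instantiates twice. A standalone sketch of the same pattern (the names here are hypothetical, not from the diff):

#include <cstdio>

template <bool kIsPaged128>
void run_impl() {
  std::printf("compiled for paged128=%d\n", int(kIsPaged128));
}

#define DISPATCH_BOOL_SKETCH(expr, const_name, ...) \
  [&]() -> bool {                                   \
    if (expr) {                                     \
      constexpr bool const_name = true;             \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr bool const_name = false;            \
      return __VA_ARGS__();                         \
    }                                               \
  }()

int main(int argc, char**) {
  int page_size = argc > 1 ? 64 : 128;  // pretend this comes from a tensor shape
  DISPATCH_BOOL_SKETCH(page_size == 128, kIsPaged128, [&] {
    run_impl<kIsPaged128>();
    return true;
  });
  return 0;
}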
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has a bug with manual split_kv: the kernel
// can hang when the batch is large and num_kv_splits is large (for example
// bs=8, num_kv_splits=8). Per-batch split_kv may fix this.
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
||||
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
||||
arguments.problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
}
|
||||
|
||||
#endif
|
||||
// clang-format on
|
@@ -18,12 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@@ -187,7 +182,6 @@ void paged_attention_v1(
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@@ -18,12 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@@ -197,7 +192,6 @@ void paged_attention_v2(
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@@ -33,6 +33,8 @@ namespace vec_op {
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
// Number of elements in single ASIMD vector of given Datatype
|
||||
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
@@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
@@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
#endif
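The new save(ptr, elem_num) overloads above all follow the same shape: store whole vector registers for the full blocks, then write the remaining tail elements one lane at a time. A NEON-free sketch of that pattern, for illustration only:

#include <array>
#include <cstdio>
#include <cstring>

int main() {
  std::array<float, 16> reg;            // stands in for a float32x4x4_t
  for (int i = 0; i < 16; ++i) reg[i] = float(i);

  float out[16] = {};
  const int elem_num = 11;              // partial store
  const int lanes = 4;                  // NUM_ELEMENTS_REG for a 128-bit f32 register
  int full_blocks = elem_num / lanes;
  int remainder = elem_num % lanes;

  // vst1q_* equivalent: copy the whole 4-lane blocks.
  std::memcpy(out, reg.data(), full_blocks * lanes * sizeof(float));
  // Tail: per-lane stores, like the vgetq_lane_* chains above.
  for (int l = 0; l < remainder; ++l)
    out[full_blocks * lanes + l] = reg[full_blocks * lanes + l];

  for (int i = 0; i < 16; ++i) std::printf("%g ", out[i]);
  std::printf("\n");
  return 0;
}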
|
||||
|
||||
@@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
}
|
||||
};
|
||||
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int32x4x4_t reg;
|
||||
int32_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int32x4x4_t reg;
|
||||
|
||||
explicit INT32Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
|
||||
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
|
||||
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
|
||||
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
|
||||
}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
vst1q_s32(ptr, reg.val[0]);
|
||||
vst1q_s32(ptr + 4, reg.val[1]);
|
||||
vst1q_s32(ptr + 8, reg.val[2]);
|
||||
vst1q_s32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s32(
|
||||
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
int32x4_t temp = reg.val[full_blocks];
|
||||
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
|
||||
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
@@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16& v) {
|
||||
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
|
||||
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
|
||||
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
|
||||
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
|
||||
};
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
@@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
|
||||
return FP32Vec16(float32x4x4_t(
|
||||
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
|
||||
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
|
||||
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
|
||||
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
|
||||
vmaxq_f32(b.reg.val[1], reg.val[1]),
|
||||
vmaxq_f32(b.reg.val[2], reg.val[2]),
|
||||
vmaxq_f32(b.reg.val[3], reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
|
||||
FP32Vec16 min(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vminq_f32(b.reg.val[0], reg.val[0]),
|
||||
vminq_f32(b.reg.val[1], reg.val[1]),
|
||||
vminq_f32(b.reg.val[2], reg.val[2]),
|
||||
vminq_f32(b.reg.val[3], reg.val[3]),
|
||||
}));
|
||||
};
|
||||
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(
|
||||
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
|
||||
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
|
||||
}
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
@@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
float reduce_max() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float max_v = std::numeric_limits<float>::lowest();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
|
||||
return max_v;
|
||||
}
|
||||
|
||||
float reduce_min() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float min_v = std::numeric_limits<float>::max();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
|
||||
return min_v;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
vst1q_f32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_f32(
|
||||
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float32x4_t temp = reg.val[full_blocks];
|
||||
float* base = reinterpret_cast<float32_t*>(ptr) +
|
||||
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
|
||||
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int8x16_t reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int8x16_t reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) {
|
||||
// Convert each 128-bit float32 vector to int32
|
||||
int32x4_t part0 =
|
||||
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
|
||||
int32x4_t part1 =
|
||||
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
|
||||
int32x4_t part2 =
|
||||
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
|
||||
int32x4_t part3 =
|
||||
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
|
||||
|
||||
// Narrow each 32-bit vector to 8 bits and combine
|
||||
int8x8_t lower =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
|
||||
int8x8_t upper =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
|
||||
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
|
||||
}
|
||||
|
||||
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
|
||||
if (remainder > 0) {
|
||||
int8x16_t temp = reg;
|
||||
int8_t* base =
|
||||
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
|
||||
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
|
||||
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
|
||||
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
|
||||
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
|
||||
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
|
||||
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
|
||||
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
|
||||
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
|
||||
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
|
||||
}
|
||||
};
|
||||
};
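The FP32 to INT8 conversion above relies on vcvtq_s32_f32 (round toward zero) followed by saturating narrows, so out-of-range values clamp to [-128, 127] instead of wrapping. A plain C++ stand-in for that behaviour, illustration only:

#include <algorithm>
#include <cstdint>
#include <cstdio>

static int8_t saturate_to_s8(float v) {
  long r = static_cast<long>(v);  // truncates toward zero, like vcvtq_s32_f32
  r = std::min<long>(127, std::max<long>(-128, r));  // saturate, like vqmovn
  return static_cast<int8_t>(r);
}

int main() {
  const float in[4] = {3.7f, -200.f, 500.f, -0.4f};
  for (float v : in) std::printf("%g -> %d\n", v, int(saturate_to_s8(v)));
  return 0;
}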
|
||||
|
||||
template <typename T>
|
||||
|
@@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
@@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
@@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
@@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
@@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
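The hunk above lets oneDNN choose its preferred weight layout on AArch64 (via format_tag::any) and reorders the user-supplied weights only when the primitive's chosen layout differs from the plain one. The following standalone sketch of that pattern assumes oneDNN v3.x and uses f32 data with a hypothetical matmul_with_any_layout() wrapper, not the int8 path of DNNLPrimitiveHelper:

#include <oneapi/dnnl/dnnl.hpp>

// Standalone sketch: build a matmul primitive with format_tag::any for the
// weights and reorder the user buffer into whatever layout the primitive picked.
void matmul_with_any_layout(const float* a, const float* b, float* c,
                            dnnl::memory::dim M, dnnl::memory::dim N,
                            dnnl::memory::dim K) {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);
  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;

  // Plain row-major descriptors for the user buffers.
  dnnl::memory::desc a_md({M, K}, dt::f32, tag::ab);
  dnnl::memory::desc b_md({K, N}, dt::f32, tag::ab);
  dnnl::memory::desc c_md({M, N}, dt::f32, tag::ab);
  // format_tag::any lets the primitive choose an optimized weight layout.
  dnnl::memory::desc b_any_md({K, N}, dt::f32, tag::any);

  dnnl::matmul::primitive_desc pd(eng, a_md, b_any_md, c_md);

  dnnl::memory a_mem(a_md, eng, const_cast<float*>(a));
  dnnl::memory b_mem(b_md, eng, const_cast<float*>(b));
  dnnl::memory c_mem(c_md, eng, c);

  // Reorder only if the chosen layout differs from the user's plain layout.
  dnnl::memory weights_mem = b_mem;
  if (pd.weights_desc() != b_mem.get_desc()) {
    weights_mem = dnnl::memory(pd.weights_desc(), eng);
    dnnl::reorder(b_mem, weights_mem).execute(strm, b_mem, weights_mem);
  }

  dnnl::matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_mem},
                                  {DNNL_ARG_WEIGHTS, weights_mem},
                                  {DNNL_ARG_DST, c_mem}});
  strm.wait();
}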
@@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
using cvt_vec_type = vec_op::FP32Vec16;
};

#ifdef __AVX512F__
#if defined(__AVX512F__) || defined(__aarch64__)
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp,
@@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int32_t* azp,
const int num_tokens,
const int hidden_size) {
TORCH_CHECK(
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
TORCH_CHECK(false,
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
"support.")
}

template <typename scalar_t>
@@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, int32_t* azp,
const int num_tokens,
const int hidden_size) {
TORCH_CHECK(
false,
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
TORCH_CHECK(false,
"dynamic_scaled_int8_quant_impl requires "
"AVX512/powerpc64/AArch64 support.")
}

template <bool PerChannel, typename scalar_t>
@@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
TORCH_CHECK(
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
}

template <typename scalar_t>
@@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
const int32_t* azp, const int32_t* azp_with_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false,
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
TORCH_CHECK(
false,
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
}
#endif
} // namespace
@@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);

// Quantization
#ifdef __AVX512F__
#if defined(__AVX512F__) || defined(__aarch64__)
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;

// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
@@ -4,10 +4,10 @@
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#if defined(USE_ROCM) && defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE warpSize
#define WARP_SIZE 32
#endif

#ifndef USE_ROCM
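The hunk above defines WARP_SIZE as 64 only for GFX9 AMD GPUs and 32 everywhere else, instead of falling back to the runtime warpSize symbol. A minimal sketch of why the compile-time constant matters, assuming the CUDA shuffle intrinsics (the ROCm builtins differ slightly):

// Warp-wide sum whose loop bound must match the hardware warp width:
// 32 on NVIDIA and non-GFX9 AMD, 64 on GFX9.
__device__ float warp_reduce_sum(float val) {
#pragma unroll
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    val += __shfl_xor_sync(0xffffffffu, val, offset);
  }
  return val;
}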
@@ -153,7 +153,7 @@ struct ScaledEpilogueBias
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
@@ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp
EVTComputeAzp>;

using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
@@ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken
EVTComputeAcc>;

using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:

@@ -195,7 +195,7 @@ struct ScaledEpilogueBias
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
@@ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
@@ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
@@ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::homogeneous_multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

public:
@@ -45,7 +45,6 @@
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
@@ -1,656 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < -final_state_position; ++w){
|
||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
||||
}
|
||||
}
|
||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
||||
conv_states[w] = vals_load[w + final_state_position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
||||
if (conv_states != nullptr && tidx == last_thread) {
|
||||
input_t x_vals_load[kNElts * 2] = {0};
|
||||
// in case we are on the first kWidth tokens
|
||||
if (last_thread == 0 && seqlen < kWidth){
|
||||
// Need to take the initial state
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
||||
const int offset = seqlen - (kWidth - 1);
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
// pad the existing state
|
||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
if (offset + w >= 0)
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// in case the final state is in between the threads data
|
||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
||||
// illegal access error on H100.
|
||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kIsCircularBuffer>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
if (channel_id >= params.dim) return;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
|
||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
||||
? batch_id
|
||||
: params.conv_state_indices_ptr[batch_id];
|
||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
||||
if (conv_state_batch_coord == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
int state_len = params.conv_state_len;
|
||||
int advance_len = params.seqlen;
|
||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
||||
int update_idx = cache_seqlen - (kWidth - 1);
|
||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) {
|
||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
||||
}
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < params.seqlen; ++i) {
|
||||
input_t x_val = x[i * params.x_l_stride];
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
||||
}
|
||||
} else {
|
||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
||||
++update_idx;
|
||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
||||
}
|
||||
x_vals[kWidth - 1] = float(x_val);
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
out[i * params.out_l_stride] = input_t(out_val);
|
||||
// Shift the input buffer by 1
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = params.cache_seqlens == nullptr
|
||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
@@ -1,159 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
int64_t pad_slot_id;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
int conv_state_len;
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
void *__restrict__ query_start_loc_ptr;
|
||||
void *__restrict__ has_initial_state_ptr;
|
||||
void *__restrict__ cache_indices_ptr;
|
||||
int32_t *__restrict__ cache_seqlens;
|
||||
|
||||
// For the continuous batching case. Makes it so that the mamba state for
|
||||
// the current batch doesn't need to be a contiguous tensor.
|
||||
int32_t *__restrict__ conv_state_indices_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
|
||||
void * conv_states_ptr;
|
||||
index_t conv_states_batch_stride;
|
||||
index_t conv_states_l_stride;
|
||||
index_t conv_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
@@ -1,28 +0,0 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h

#pragma once

/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
static constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
static constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
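The deleted header above is the standard BOOL_SWITCH trick: an immediately invoked lambda that turns a runtime boolean into a constexpr template parameter. A minimal usage sketch, where run_kernel is a hypothetical stand-in for the launch helpers in this diff:

template <bool kIsEvenLen>
void run_kernel(int seqlen);  // hypothetical launcher, not part of the diff

void dispatch(int seqlen, int chunk) {
  // The runtime check becomes the compile-time constant kIsEvenLen inside the
  // lambda, so both template instantiations exist and the right one is picked
  // at runtime.
  BOOL_SWITCH(seqlen % chunk == 0, kIsEvenLen, [&] {
    run_kernel<kIsEvenLen>(seqlen);
  });
}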
@@ -7,7 +7,11 @@

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#ifdef USE_ROCM
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif

#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
@@ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
constexpr bool kIsVariableB = true;
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
if (kSmemSize >= 48 * 1024) {
#ifdef USE_ROCM
C10_HIP_CHECK(hipFuncSetAttribute(
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#else
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
#endif
}
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
}
@@ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,

at::Tensor z, out_z;
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
if (has_z) {
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}

out_z = z;
}

out_z = z;

// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
@@ -653,4 +664,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
}
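The launch hunk above both re-keys the outer BOOL_SWITCH on params.z_ptr (kHasZ) and guards the dynamic shared-memory attribute call for ROCm. The ROCm/CUDA split can be read in isolation as the following sketch, assuming the same C10 check macros; set_max_dynamic_smem is a hypothetical helper, not part of the diff:

template <typename KernelT>
void set_max_dynamic_smem(KernelT* kernel, int smem_bytes) {
  // Opting in is only required once the per-block request exceeds the
  // default 48 KiB limit.
  if (smem_bytes >= 48 * 1024) {
#ifdef USE_ROCM
    C10_HIP_CHECK(hipFuncSetAttribute(
        reinterpret_cast<const void*>(kernel),
        hipFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
#else
    C10_CUDA_CHECK(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
#endif
  }
}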
csrc/ops.h
@@ -167,6 +167,19 @@ void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table, double scale);

void sm100_cutlass_mla_decode(
torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& q_pe, torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens, torch::Tensor const& page_table,
torch::Tensor const& workspace, double sm_scale,
int64_t num_kv_splits =
1 /* Set to 1 to avoid cuda_graph issue by default. */);

int64_t sm100_cutlass_mla_get_workspace_size(
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0,
int64_t num_kv_splits =
1 /* Set to 1 to avoid cuda_graph issue by default. */);

torch::Tensor get_cuda_view_from_cpu_tensor(torch::Tensor& cpu_tensor);

#ifndef USE_ROCM
@@ -326,22 +339,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);

void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
bool silu_activation,
const std::optional<at::Tensor>& cache_seqlens_,
const std::optional<at::Tensor>& conv_state_indices_,
int64_t pad_slot_id);

void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
const std::optional<at::Tensor>& bias_,
const std::optional<at::Tensor>& conv_states,
const std::optional<at::Tensor>& query_start_loc,
const std::optional<at::Tensor>& cache_indices,
const std::optional<at::Tensor>& has_initial_state,
bool silu_activation, int64_t pad_slot_id);

using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,
@@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(

// calculate for absmax
float thread_max = 0.f;
for (int i = tid; i < hidden_size; i += stride) {
const auto v = fabsf(static_cast<float>(row_in[i]));
thread_max = fmaxf(thread_max, v);
}
vectorize_read_with_alignment<16>(
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
const float v = fabsf(static_cast<float>(src));
thread_max = fmaxf(thread_max, v);
});
using BlockReduce = cub::BlockReduce<float, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
@@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(

// 1. calculate min & max
MinMax thread_mm;
for (int i = tid; i < hidden_size; i += stride) {
thread_mm += static_cast<float>(row_in[i]);
}
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
[&] __device__(const scalar_t& src) {
thread_mm += static_cast<float>(src);
});

using BlockReduce = cub::BlockReduce<MinMax, 256>;
__shared__ typename BlockReduce::TempStorage tmp;
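The two hunks above only change how each thread reads its strided slice of the row (through vectorize_read_with_alignment); the per-block reduction stays the same. A scalar sketch of that surrounding absmax reduction, using a hypothetical row_absmax_kernel and a block size of 256 to match the BlockReduce instantiation:

#include <cub/block/block_reduce.cuh>

template <typename scalar_t>
__global__ void row_absmax_kernel(const scalar_t* __restrict__ in,
                                  float* __restrict__ row_max,
                                  int hidden_size) {
  const scalar_t* row_in = in + blockIdx.x * (int64_t)hidden_size;
  // Each thread scans a strided slice of its row.
  float thread_max = 0.f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    thread_max = fmaxf(thread_max, fabsf(static_cast<float>(row_in[i])));
  }
  // Block-wide max, as in the kernels above.
  using BlockReduce = cub::BlockReduce<float, 256>;
  __shared__ typename BlockReduce::TempStorage tmp;
  float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
  if (threadIdx.x == 0) row_max[blockIdx.x] = block_max;
}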
@@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
// These are the minimum alignments needed for the kernels to compile
static constexpr int AlignmentAB =
128 / cutlass::sizeof_bits<ElementAB>::value;
static constexpr int AlignmentCD = 4;
static constexpr int AlignmentCD =
128 / cutlass::sizeof_bits<ElementD>::value;

using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
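The alignment fix above derives the C/D alignment from the element width instead of hard-coding 4, so every access stays 128 bits wide regardless of the output type. For example:

// 128-bit accesses expressed in elements of the output type.
static_assert(128 / cutlass::sizeof_bits<cutlass::bfloat16_t>::value == 8);
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8);
static_assert(128 / cutlass::sizeof_bits<float>::value == 4);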
@@ -0,0 +1,373 @@
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include <cassert>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
|
||||
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
__global__ void get_ggemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
|
||||
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scale_base_as_int,
|
||||
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
|
||||
int m = problem_sizes[expert_id * 3];
|
||||
int n = problem_sizes[expert_id * 3 + 1];
|
||||
int k = problem_sizes[expert_id * 3 + 2];
|
||||
|
||||
int32_t expert_offset = expert_offsets[expert_id];
|
||||
int a_stride = expert_offset * k;
|
||||
int b_stride = expert_id * k * n;
|
||||
int a_scale_stride = expert_offset * k / 128;
|
||||
int b_scale_stride = expert_id * k * n / 128 / 128;
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + a_stride;
|
||||
b_offsets[expert_id] = b_base_as_int + b_stride;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
|
||||
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
*layout_sfb_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
|
||||
ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<int*>(problem_sizes.data_ptr())); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
void run_get_ggemm_starts(
    torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
    torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
    torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
    torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
    torch::Tensor out_tensors, torch::Tensor const& a_scales,
    torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
    torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
  TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);

  int num_experts = (int)expert_offsets.size(0);
  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());

  if (false) {
  }
  __CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
                           LayoutSFB, ScaleConfig)
  __CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
                           LayoutSFB, ScaleConfig)
  else {
    TORCH_CHECK(false, "Unsupported output tensor type");
  }
}

template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||
void run_blockwise_scaled_group_mm(
|
||||
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
|
||||
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
|
||||
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Types
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
|
||||
// Alignments
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
|
||||
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA,
|
||||
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
|
||||
AlignmentA, ElementB,
|
||||
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
|
||||
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue, void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(stride_a.data_ptr()),
|
||||
static_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(stride_b.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
|
||||
layout_sfa.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
|
||||
layout_sfb.data_ptr())};
|
||||
|
||||
int device_id = a_ptrs.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(stride_c.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void blockwise_scaled_group_mm_dispatch_shape(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
struct MmaConfig {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
|
||||
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
};
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
auto a_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto out_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto a_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
auto layout_sfa = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
auto layout_sfb = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
|
||||
auto stride_a = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_b = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_c = torch::full(
|
||||
{num_experts}, output.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
torch::TensorOptions options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
|
||||
typename MmaConfig::LayoutSFB,
|
||||
typename MmaConfig::ScaleConfig>(
|
||||
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
|
||||
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
|
||||
|
||||
run_blockwise_scaled_group_mm<OutType, MmaConfig,
|
||||
typename MmaConfig::LayoutC>(
|
||||
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
|
||||
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
|
||||
expert_offsets);
|
||||
}
|
||||
|
||||
void cutlass_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"a must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"b must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
|
||||
output.scalar_type() == torch::kHalf,
|
||||
"output must be bfloat16 or half");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
|
||||
"scales_a must be float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
|
||||
"scales_b must be float32");
|
||||
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
|
||||
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
|
||||
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
|
||||
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||
|
||||
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else if (output.scalar_type() == torch::kFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_blockwise_scaled_grouped_mm",
|
||||
&cutlass_blockwise_scaled_grouped_mm);
|
||||
}
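The `TORCH_CHECK` guards above effectively document the op's calling convention: fp8-e4m3 `a` (2-D) and `b` (3-D), fp32 blockwise scales, int32 `problem_sizes` of shape `(num_experts, 3)`, and a 1-D int32 `expert_offsets`. The sketch below shows how such an op could be driven from Python once the extension is loaded; the `torch.ops._C` namespace, the scale-tensor shapes, and the toy sizes are illustrative assumptions rather than vLLM's actual call site.

```python
# Illustrative sketch only; assumes an SM100 build and that the op is exposed
# as torch.ops._C.cutlass_blockwise_scaled_grouped_mm. Scale shapes follow the
# 1x128 / 128x128 blockwise layout implied by the kernel, but are assumptions.
import torch

E, N, K = 4, 512, 256                      # experts, output dim, reduction dim
tokens = [64, 64, 64, 64]                  # rows routed to each expert
M = sum(tokens)
dev = "cuda"

a = torch.randn(M, K, device=dev).to(torch.float8_e4m3fn)       # 2-D activations
b = torch.randn(E, N, K, device=dev).to(torch.float8_e4m3fn)    # 3-D expert weights
out = torch.empty(M, N, dtype=torch.bfloat16, device=dev)       # 2-D output

scales_a = torch.rand(M, K // 128, dtype=torch.float32, device=dev)
scales_b = torch.rand(E, N // 128, K // 128, dtype=torch.float32, device=dev)

problem_sizes = torch.tensor([[m, N, K] for m in tokens],
                             dtype=torch.int32, device=dev)      # (E, 3)
expert_offsets = torch.tensor([0, 64, 128, 192],
                              dtype=torch.int32, device=dev)     # 1-D, int32

torch.ops._C.cutlass_blockwise_scaled_grouped_mm(
    out, a, b, scales_a, scales_b, problem_sizes, expert_offsets)
```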
|
@ -7,7 +7,7 @@
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
|
||||
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
|
||||
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
|
||||
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
||||
@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
}
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
|
@ -30,35 +30,40 @@
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config
|
||||
template <typename T>
|
||||
struct KernelTraits;
|
||||
|
||||
template <>
|
||||
struct KernelTraits<float> {
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
// Configuration for M in (256, inf)
|
||||
struct sm100_fp4_config_default {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in (16, 256]
|
||||
struct sm100_fp4_config_M256 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in [1, 16]
|
||||
struct sm100_fp4_config_M16 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::half_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::bfloat16_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename Config, typename OutType>
|
||||
struct Fp4GemmSm100 {
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
|
||||
static constexpr int AlignmentB = 32;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = T;
|
||||
using ElementC = T;
|
||||
using ElementD = OutType;
|
||||
using ElementC = OutType;
|
||||
using LayoutCTag = cutlass::layout::RowMajor;
|
||||
using LayoutDTag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
||||
// Use config's tile shapes
|
||||
using MmaTileShape = typename Config::TileShape;
|
||||
using ClusterShape = typename Config::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Gemm::Arguments args_from_options(
|
||||
template <typename Config>
|
||||
typename Config::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename T::Gemm::ElementA;
|
||||
using ElementB = typename T::Gemm::ElementB;
|
||||
using ElementA = typename Config::Gemm::ElementA;
|
||||
using ElementB = typename Config::Gemm::ElementB;
|
||||
using ElementSFA = cutlass::float_ue4m3_t;
|
||||
using ElementSFB = cutlass::float_ue4m3_t;
|
||||
using ElementD = typename T::Gemm::ElementD;
|
||||
using ElementD = typename Config::Gemm::ElementD;
|
||||
using ElementCompute = float;
|
||||
using StrideA = typename T::StrideA;
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using StrideA = typename Config::StrideA;
|
||||
using StrideB = typename Config::StrideB;
|
||||
using StrideD = typename Config::StrideD;
|
||||
using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
|
||||
CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
||||
cute::make_shape(m, n, k, 1));
|
||||
|
||||
typename T::Gemm::Arguments arguments{
|
||||
typename Config::Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{// Mainloop arguments
|
||||
@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename Config>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
||||
typename Config::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (16, 256]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
// m in (256, inf)
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
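The M-bucket selection above is simple arithmetic, so it can be sanity-checked outside CUTLASS. Below is a stand-alone Python mirror of the bucketing logic (illustration only; `next_pow_2` here is a local stand-in for the helper included from `core/math.hpp`):

```python
# Illustrative Python mirror of the M-based config dispatch above (not vLLM code).
def next_pow_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def select_fp4_config(m: int) -> str:
    mp2 = max(16, next_pow_2(m))
    if mp2 <= 16:
        return "sm100_fp4_config_M16"      # m in [1, 16]
    if mp2 <= 256:
        return "sm100_fp4_config_M256"     # m in (16, 256]
    return "sm100_fp4_config_default"      # m in (256, inf)

assert select_fp4_config(8) == "sm100_fp4_config_M16"
assert select_fp4_config(200) == "sm100_fp4_config_M256"
assert select_fp4_config(4096) == "sm100_fp4_config_default"
```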
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||
k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::Float) {
|
||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||
m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
@@ -38,7 +38,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"

@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
// Note: currently the output is guaranteed to be the same as the input, so
// we don't check it here; this comment is just for future reference.
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
using vout_t = vec_n_t<OutT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
auto* v_out = reinterpret_cast<vout_t*>(out);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vout_t tmp;
|
||||
vec_op(tmp, v_in[i]);
|
||||
v_out[i] = tmp;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
||||
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
||||
@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
struct DefaultReadVecOp {
|
||||
ScaOp scalar_op;
|
||||
|
||||
__device__ __forceinline__ void operator()(
|
||||
const vec_n_t<InT, VEC_SIZE>& src) const {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
scalar_op(src.val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// read-only version: iterate over the input with alignment guarantees
|
||||
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
|
||||
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
|
||||
int tid, int stride,
|
||||
VecOp&& vec_op,
|
||||
ScaOp&& scalar_op) {
|
||||
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
|
||||
"VEC_SIZE must be a positive power-of-two");
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT);
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1);
|
||||
int alignment_bytes = WIDTH - misalignment_offset;
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1);
|
||||
prefix_elems /= sizeof(InT);
|
||||
prefix_elems = min(prefix_elems, len);
|
||||
|
||||
// 1. handle the possibly unaligned prefix with scalar access.
|
||||
for (int i = tid; i < prefix_elems; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
|
||||
in += prefix_elems;
|
||||
len -= prefix_elems;
|
||||
|
||||
int num_vec = len / VEC_SIZE;
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
// 2. vectorized traversal of the main aligned region.
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
|
||||
// 3. handle remaining tail elements.
|
||||
int tail_start = num_vec * VEC_SIZE;
|
||||
for (int i = tid + tail_start; i < len; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// overload that requires only a scalar_op
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
__device__ __forceinline__ void vectorize_read_with_alignment(
|
||||
const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
|
||||
using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
|
||||
vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
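The prefix / aligned-body / tail split used by both helpers is plain index arithmetic, so it can be sanity-checked on the host. Below is a small Python mirror of that arithmetic (illustration only; the device code operates on raw pointers and `vec_n_t` vectors):

```python
# Pure-Python mirror of the prefix / aligned-body / tail split used above.
def split_for_vectorization(addr: int, length: int, vec_size: int, elem_bytes: int):
    width = vec_size * elem_bytes                  # bytes covered by one vector
    misalignment = addr & (width - 1)              # addr % width
    prefix_elems = ((width - misalignment) & (width - 1)) // elem_bytes
    prefix_elems = min(prefix_elems, length)       # scalar prefix up to alignment
    body = length - prefix_elems
    num_vec = body // vec_size                     # full vectors in the aligned body
    tail = body - num_vec * vec_size               # leftover scalars at the end
    return prefix_elems, num_vec, tail

# e.g. 4-byte elements, 16-element vectors (64 B), pointer 8 bytes past alignment
print(split_for_vectorization(0x1008, 100, 16, 4))   # -> (14, 5, 6)
```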
|
||||
|
||||
} // namespace vllm
|
||||
|
@@ -59,6 +59,8 @@ void apply_repetition_penalties_(
  int vocab_size = logits.size(-1);
  int num_seqs = logits.size(0);

  if (num_seqs == 0) return;

  // Get number of SMs on the current device
  int sms = 0;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
@@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm
  // These are the minimum alignments needed for the kernels to compile
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD = 4;
  static constexpr int AlignmentCD =
      128 / cutlass::sizeof_bits<ElementD>::value;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -393,6 +393,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      {stride_tag});
  ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);

  // cutlass blockwise scaled group GEMM
  ops.def(
      "cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
      "Tensor scales_a, Tensor scales_b, "
      "Tensor problem_sizes, Tensor expert_offsets) -> ()",
      {stride_tag});
  // conditionally compiled so impl registration is in source file

// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
@ -506,6 +514,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
ops.impl("sm100_cutlass_mla_decode", torch::kCUDA, &sm100_cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
ops.impl("sm100_cutlass_mla_get_workspace_size",
|
||||
&sm100_cutlass_mla_get_workspace_size);
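The two schemas above suggest a query-then-allocate pattern: size the workspace once, then reuse it across decode calls. A hedged sketch of that pattern follows; the `torch.ops._C` namespace, the placeholder sizes, and the `uint8` workspace dtype are assumptions for illustration, not the project's actual call site.

```python
# Hypothetical workspace-sizing sketch; values and op namespace are assumptions.
import torch

max_seq_len, num_batches, num_kv_splits = 8192, 32, 1
sm_count = torch.cuda.get_device_properties(0).multi_processor_count

workspace_bytes = torch.ops._C.sm100_cutlass_mla_get_workspace_size(
    max_seq_len, num_batches, sm_count, num_kv_splits)
workspace = torch.empty(int(workspace_bytes), dtype=torch.uint8, device="cuda")
# `workspace` would then be passed to sm100_cutlass_mla_decode(...) on each call.
```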
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
@ -586,28 +611,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation,"
|
||||
"Tensor? cache_seqlens_,"
|
||||
"Tensor? conv_state_indices,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor!? conv_states,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"bool silu_activation,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
@ -62,12 +63,16 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
|
||||
ARG PIP_KEYRING_PROVIDER=disabled
|
||||
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
|
||||
|
||||
# Flag that enables building KV-connector dependency libs into the docker images
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
@ -276,6 +281,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
@ -373,38 +379,38 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
|
||||
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.8rc1"
|
||||
# Flag to control whether to use pre-built FlashInfer wheels (set to false to force build from source)
|
||||
# TODO: Currently disabled because the pre-built wheels are not available for FLASHINFER_GIT_REF
|
||||
ARG USE_FLASHINFER_PREBUILT_WHEEL=false
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]] && [[ "$USE_FLASHINFER_PREBUILT_WHEEL" == "true" ]]; then
|
||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
|
||||
else
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
|
||||
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
)
|
||||
rm -rf flashinfer
|
||||
|
||||
# Default arches (skipping 10.0a and 12.0 since these need 12.8)
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
|
||||
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch v0.2.6.post1 \
|
||||
https://github.com/flashinfer-ai/flashinfer.git flashinfer
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
|
||||
# Needed to build AOT kernels
|
||||
pushd flashinfer
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation .
|
||||
popd
|
||||
|
||||
@ -485,6 +491,7 @@ RUN mv mkdocs.yaml test_docs/
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
@ -493,8 +500,13 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
COPY requirements/kv_connectors.txt requirements/kv_connectors.txt
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r requirements/kv_connectors.txt; \
|
||||
fi; \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
else \
|
||||
|
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="6487649"
|
||||
ARG AITER_BRANCH="916bf3c"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate hf_transfer 'modelscope!=1.15.0'
|
||||
pip install accelerate hf_transfer pytest 'modelscope!=1.15.0'
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image \
|
||||
TRITON_XPU_PROFILE 1
|
||||
|
@ -55,6 +55,7 @@ nav:
|
||||
- contributing/model/registration.md
|
||||
- contributing/model/tests.md
|
||||
- contributing/model/multimodal.md
|
||||
- CI: contributing/ci
|
||||
- Design Documents:
|
||||
- V0: design
|
||||
- V1: design/v1
|
||||
|
@ -36,7 +36,7 @@ vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular HuggingFace models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
@ -48,4 +48,4 @@ For more information, check out the following:
|
||||
- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention)
|
||||
- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023)
|
||||
- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al.
|
||||
- [vLLM Meetups][meetups]
|
||||
- [vLLM Meetups](community/meetups.md)
|
||||
|
@ -8,7 +8,6 @@ API documentation for vLLM's configuration classes.
|
||||
|
||||
- [vllm.config.ModelConfig][]
|
||||
- [vllm.config.CacheConfig][]
|
||||
- [vllm.config.TokenizerPoolConfig][]
|
||||
- [vllm.config.LoadConfig][]
|
||||
- [vllm.config.ParallelConfig][]
|
||||
- [vllm.config.SchedulerConfig][]
|
||||
@ -64,7 +63,7 @@ vLLM provides experimental support for multi-modal models through the [vllm.mult
|
||||
Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models]
|
||||
via the `multi_modal_data` field in [vllm.inputs.PromptType][].
|
||||
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here][supports-multimodal].
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here](../contributing/model/multimodal.md).
|
||||
|
||||
- [vllm.multimodal.MULTIMODAL_REGISTRY][]
|
||||
|
||||
|
BIN docs/assets/deployment/dp_external_lb.png (new binary file, 84 KiB; not shown)
BIN docs/assets/deployment/dp_internal_lb.png (new binary file, 68 KiB; not shown)
@ -1,3 +1,7 @@
|
||||
---
|
||||
toc_depth: 4
|
||||
---
|
||||
|
||||
# vLLM CLI Guide
|
||||
|
||||
The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with:
|
||||
@ -16,7 +20,7 @@ vllm {chat,complete,serve,bench,collect-env,run-batch}
|
||||
|
||||
Start the vLLM OpenAI Compatible API server.
|
||||
|
||||
??? Examples
|
||||
??? console "Examples"
|
||||
|
||||
```bash
|
||||
# Start with a model
|
||||
@ -37,8 +41,15 @@ Start the vLLM OpenAI Compatible API server.
|
||||
|
||||
# To search by keyword
|
||||
vllm serve --help=max
|
||||
|
||||
# To view full help with pager (less/more)
|
||||
vllm serve --help=page
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
--8<-- "docs/argparse/serve.md"
|
||||
|
||||
## chat
|
||||
|
||||
Generate chat completions via the running API server.
|
||||
|
@ -1,6 +1,3 @@
|
||||
---
|
||||
title: Contact Us
|
||||
---
|
||||
[](){ #contactus }
|
||||
# Contact Us
|
||||
|
||||
--8<-- "README.md:contact-us"
|
||||
|
@ -1,7 +1,4 @@
|
||||
---
|
||||
title: Meetups
|
||||
---
|
||||
[](){ #meetups }
|
||||
# Meetups
|
||||
|
||||
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
|
||||
|
||||
|
@ -33,7 +33,7 @@ Quantized models take less memory at the cost of lower precision.
|
||||
Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI))
|
||||
and used directly without extra configuration.
|
||||
|
||||
Dynamic quantization is also supported via the `quantization` option -- see [here][quantization-index] for more details.
|
||||
Dynamic quantization is also supported via the `quantization` option -- see [here](../features/quantization/README.md) for more details.
|
||||
|
||||
## Context length and batch size
|
||||
|
||||
@ -57,7 +57,7 @@ By default, we optimize model inference using CUDA graphs which take up extra me
|
||||
|
||||
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
@ -129,7 +129,7 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.
|
||||
|
||||
Here are some examples:
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
@@ -1,18 +1,20 @@
---
title: Engine Arguments
toc_depth: 3
---
[](){ #engine-args }

# Engine Arguments

Engine arguments control the behavior of the vLLM engine.

- For [offline inference][offline-inference], they are part of the arguments to [LLM][vllm.LLM] class.
- For [online serving][serving-openai-compatible-server], they are part of the arguments to `vllm serve`.
- For [offline inference](../serving/offline_inference.md), they are part of the arguments to [LLM][vllm.LLM] class.
- For [online serving](../serving/openai_compatible_server.md), they are part of the arguments to `vllm serve`.

You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments.
The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings.

However, these classes are a combination of the configuration classes defined in [vllm.config][]. Therefore, we would recommend you read about them there where they are best documented.
## `EngineArgs`

For offline inference you will have access to these configuration classes and for online serving you can cross-reference the configs with `vllm serve --help`, which has its arguments grouped by config.
--8<-- "docs/argparse/engine_args.md"

!!! note
    Additional arguments are available to the [AsyncLLMEngine][vllm.engine.async_llm_engine.AsyncLLMEngine] which is used for online serving. These can be found by running `vllm serve --help`
## `AsyncEngineArgs`

--8<-- "docs/argparse/async_engine_args.md"
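
As the section above notes, for offline inference the engine arguments are simply constructor arguments of `LLM`. A minimal, illustrative example (the model name and values are placeholders, not recommendations):

```python
# Minimal offline-inference sketch; argument values are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # any supported model id or local path
    max_model_len=2048,             # EngineArgs field
    gpu_memory_utilization=0.9,     # EngineArgs field
    tensor_parallel_size=1,         # EngineArgs field
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)
```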
@ -7,7 +7,7 @@ vLLM uses the following environment variables to configure the system:
|
||||
|
||||
All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables).
|
||||
|
||||
??? Code
|
||||
??? code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/envs.py:env-vars-definition"
|
||||
|
@ -20,4 +20,4 @@ model = LLM(
|
||||
)
|
||||
```
|
||||
|
||||
Our [list of supported models][supported-models] shows the model architectures that are recognized by vLLM.
|
||||
Our [list of supported models](../models/supported_models.md) shows the model architectures that are recognized by vLLM.
|
||||
|
@@ -1,19 +1,16 @@
---
title: Server Arguments
---
[](){ #serve-args }
# Server Arguments

The `vllm serve` command is used to launch the OpenAI-compatible server.

## CLI Arguments

The `vllm serve` command is used to launch the OpenAI-compatible server.
To see the available CLI arguments, run `vllm serve --help`!
To see the available options, take a look at the [CLI Reference](../cli/README.md#options)!

## Configuration file

You can load CLI arguments via a [YAML](https://yaml.org/) config file.
The argument names must be the long form of those outlined [above][serve-args].
The argument names must be the long form of those outlined [above](serve_args.md).

For example:

@ -95,7 +95,7 @@ For additional features and advanced configurations, refer to the official [MkDo
|
||||
|
||||
## Testing
|
||||
|
||||
??? note "Commands"
|
||||
??? console "Commands"
|
||||
|
||||
```bash
|
||||
pip install -r requirements/dev.txt
|
||||
|
@ -1,7 +1,4 @@
|
||||
---
|
||||
title: Benchmark Suites
|
||||
---
|
||||
[](){ #benchmarks }
|
||||
# Benchmark Suites
|
||||
|
||||
vLLM contains two sets of benchmarks:
|
||||
|
||||
|
@ -6,9 +6,9 @@ the failure?
|
||||
- Check the dashboard of current CI test failures:
|
||||
👉 [CI Failures Dashboard](https://github.com/orgs/vllm-project/projects/20)
|
||||
|
||||
- If your failure **is already listed**, it's likely unrelated to your PR.
|
||||
Help fixing it is always welcome!
|
||||
- Leave comments with links to additional instances of the failure.
|
||||
- If your failure **is already listed**, it's likely unrelated to your PR.
|
||||
Help fixing it is always welcome!
|
||||
- Leave comments with links to additional instances of the failure.
|
||||
- React with a 👍 to signal how many are affected.
|
||||
|
||||
- If your failure **is not listed**, you should **file an issue**.
|
||||
@ -19,25 +19,25 @@ the failure?
|
||||
👉 [New CI Failure Report](https://github.com/vllm-project/vllm/issues/new?template=450-ci-failure.yml)
|
||||
|
||||
- **Use this title format:**
|
||||
|
||||
|
||||
```
|
||||
[CI Failure]: failing-test-job - regex/matching/failing:test
|
||||
```
|
||||
|
||||
- **For the environment field:**
|
||||
|
||||
|
||||
```
|
||||
Still failing on main as of commit abcdef123
|
||||
```
|
||||
|
||||
- **In the description, include failing tests:**
|
||||
|
||||
|
||||
```
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
https://github.com/orgs/vllm-project/projects/20
|
||||
https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
https://github.com/orgs/vllm-project/projects/20
|
||||
https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
```
|
||||
|
||||
- **Attach logs** (collapsible section example):
|
||||
@ -45,17 +45,17 @@ the failure?
|
||||
<summary>Logs:</summary>
|
||||
|
||||
```text
|
||||
ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data
|
||||
ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data
|
||||
--- Logging error ---
|
||||
Traceback (most recent call last):
|
||||
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 203, in execute_model
|
||||
return self.model_executor.execute_model(scheduler_output)
|
||||
return self.model_executor.execute_model(scheduler_output)
|
||||
...
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
FAILED failing/test.py:failing_test1 - Failure description
|
||||
FAILED failing/test.py:failing_test2 - Failure description
|
||||
FAILED failing/test.py:failing_test3 - Failure description
|
||||
```
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
## Logs Wrangling
|
||||
@ -78,7 +78,7 @@ tail -525 ci_build.log | wl-copy
|
||||
|
||||
## Investigating a CI Test Failure
|
||||
|
||||
1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main)
|
||||
1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main)
|
||||
2. Bisect to find the first build that shows the issue.
|
||||
3. Add your findings to the GitHub issue.
|
||||
4. If you find a strong candidate PR, mention it in the issue and ping contributors.
|
||||
@ -97,9 +97,9 @@ CI test failures may be flaky. Use a bash loop to run repeatedly:
|
||||
|
||||
If you submit a PR to fix a CI failure:
|
||||
|
||||
- Link the PR to the issue:
|
||||
- Link the PR to the issue:
|
||||
Add `Closes #12345` to the PR description.
|
||||
- Add the `ci-failure` label:
|
||||
- Add the `ci-failure` label:
|
||||
This helps track it in the [CI Failures GitHub Project](https://github.com/orgs/vllm-project/projects/20).
|
||||
|
||||
## Other Resources
|
@ -1,15 +1,12 @@
|
||||
---
|
||||
title: Update PyTorch version on vLLM OSS CI/CD
|
||||
---
|
||||
# Update PyTorch version on vLLM OSS CI/CD
|
||||
|
||||
vLLM's current policy is to always use the latest PyTorch stable
|
||||
release in CI/CD. It is standard practice to submit a PR to update the
|
||||
PyTorch version as early as possible when a new [PyTorch stable
|
||||
release](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-cadence) becomes available.
|
||||
This process is non-trivial due to the gap between PyTorch
|
||||
releases. Using [#16859](https://github.com/vllm-project/vllm/pull/16859) as
|
||||
an example, this document outlines common steps to achieve this update along with
|
||||
a list of potential issues and how to address them.
|
||||
releases. Using <gh-pr:16859> as an example, this document outlines common steps to achieve this
|
||||
update along with a list of potential issues and how to address them.
|
||||
|
||||
## Test PyTorch release candidates (RCs)
|
||||
|
||||
@ -19,11 +16,12 @@ by waiting for the next release or by implementing hacky workarounds in vLLM.
|
||||
The better solution is to test vLLM with PyTorch release candidates (RC) to ensure
|
||||
compatibility before each release.
|
||||
|
||||
PyTorch release candidates can be downloaded from PyTorch test index at https://download.pytorch.org/whl/test.
|
||||
For example, torch2.7.0+cu12.8 RC can be installed using the following command:
|
||||
PyTorch release candidates can be downloaded from [PyTorch test index](https://download.pytorch.org/whl/test).
|
||||
For example, `torch2.7.0+cu12.8` RC can be installed using the following command:
|
||||
|
||||
```
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
|
||||
```bash
|
||||
uv pip install torch torchvision torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/test/cu128
|
||||
```
|
||||
|
||||
When the final RC is ready for testing, it will be announced to the community
|
||||
@ -31,13 +29,28 @@ on the [PyTorch dev-discuss forum](https://dev-discuss.pytorch.org/c/release-ann
|
||||
After this announcement, we can begin testing vLLM integration by drafting a pull request
|
||||
following this 3-step process:
|
||||
|
||||
1. Update requirements files in https://github.com/vllm-project/vllm/tree/main/requirements
|
||||
to point to the new releases for torch, torchvision, and torchaudio.
|
||||
2. Use `--extra-index-url https://download.pytorch.org/whl/test/<PLATFORM>` to
|
||||
get the final release candidates' wheels. Some common platforms are `cpu`, `cu128`,
|
||||
and `rocm6.2.4`.
|
||||
3. As vLLM uses uv, make sure that `unsafe-best-match` strategy is set either
|
||||
via `UV_INDEX_STRATEGY` env variable or via `--index-strategy unsafe-best-match`.
|
||||
1. Update [requirements files](https://github.com/vllm-project/vllm/tree/main/requirements)
|
||||
to point to the new releases for `torch`, `torchvision`, and `torchaudio`.
|
||||
|
||||
2. Use the following option to get the final release candidates' wheels. Some common platforms are `cpu`, `cu128`, and `rocm6.2.4`.
|
||||
|
||||
```bash
|
||||
--extra-index-url https://download.pytorch.org/whl/test/<PLATFORM>
|
||||
```
|
||||
|
||||
3. Since vLLM uses `uv`, ensure the following index strategy is applied:
|
||||
|
||||
- Via environment variable:
|
||||
|
||||
```bash
|
||||
export UV_INDEX_STRATEGY=unsafe-best-match
|
||||
```
|
||||
|
||||
- Or via CLI flag:
|
||||
|
||||
```bash
|
||||
--index-strategy unsafe-best-match
|
||||
```
|
||||
|
||||
If failures are found in the pull request, raise them as issues on vLLM and
|
||||
cc the PyTorch release team to initiate discussion on how to address them.
|
||||
@ -45,20 +58,25 @@ cc the PyTorch release team to initiate discussion on how to address them.
|
||||

## Update CUDA version

The PyTorch release matrix includes both stable and experimental [CUDA versions](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix). Due to limitations, only the latest stable CUDA version (for example,
`torch2.7.0+cu12.6`) is uploaded to PyPI. However, vLLM may require a different CUDA version,
such as 12.8 for Blackwell support.
This complicates the process as we cannot use the out-of-the-box
`pip install torch torchvision torchaudio` command. The solution is to use
`--extra-index-url` in vLLM's Dockerfiles (a sketch of what this looks like follows the list below).

1. Use `--extra-index-url https://download.pytorch.org/whl/cu128` to install torch+cu128.
2. Important indexes at the moment include:

    | Platform | `--extra-index-url` |
    |----------|---------------------|
    | CUDA 12.8 | [https://download.pytorch.org/whl/cu128](https://download.pytorch.org/whl/cu128) |
    | CPU | [https://download.pytorch.org/whl/cpu](https://download.pytorch.org/whl/cpu) |
    | ROCm 6.2 | [https://download.pytorch.org/whl/rocm6.2.4](https://download.pytorch.org/whl/rocm6.2.4) |
    | ROCm 6.3 | [https://download.pytorch.org/whl/rocm6.3](https://download.pytorch.org/whl/rocm6.3) |
    | XPU | [https://download.pytorch.org/whl/xpu](https://download.pytorch.org/whl/xpu) |

3. Update the files below to match the CUDA version from step 1. This makes sure that the release vLLM wheel is tested on CI.
    - `.buildkite/release-pipeline.yaml`
    - `.buildkite/scripts/upload-wheels.sh`
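
For orientation, the kind of install line this produces inside a Dockerfile is sketched below; the exact packages, pins, and location in vLLM's Dockerfiles differ, so treat it purely as an illustration:

```bash
# Illustration: pull the CUDA 12.8 builds of the torch stack from the PyTorch
# index while every other dependency still resolves from PyPI.
uv pip install --system torch torchvision torchaudio \
    --extra-index-url https://download.pytorch.org/whl/cu128
```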

## Address long vLLM build time

@@ -68,8 +86,8 @@ and timeout. Additionally, since vLLM's fastcheck pipeline runs in read-only mod
it doesn't populate the cache, so re-running it to warm up the cache
is ineffective.

While ongoing efforts like [#17419](gh-issue:17419)
address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH`
to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`)
when manually triggering a build on Buildkite. This branch accomplishes two things:

@@ -89,17 +107,18 @@ releases (which would take too much time), they can be built from
source to unblock the update process.

### FlashInfer

Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271):

```bash
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
export FLASHINFER_ENABLE_SM90=1
uv pip install --system \
    --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
```

One caveat is that building FlashInfer from source adds approximately 30
minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a
public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release
team if you want to get the package published there.
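
Installing from such a cached wheel is then a one-liner; the sketch below uses the wheel URL mentioned above and skips the half-hour source build:

```bash
# Sketch: install the pre-built FlashInfer wheel instead of compiling from source.
uv pip install --system \
    "https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
```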

### xFormers

@@ -107,13 +126,15 @@ Similar to FlashInfer, here is how to build and install xFormers from source:

```bash
export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
MAX_JOBS=16 uv pip install --system \
    --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
```

### Mamba

```bash
uv pip install --system \
    --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
```

### causal-conv1d

@@ -128,7 +149,6 @@ Rather than attempting to update all vLLM platforms in a single pull request, it
to handle some platforms separately. The separation of requirements and Dockerfiles
for different platforms in vLLM CI/CD allows us to selectively choose
which platforms to update. For instance, updating XPU requires the corresponding
release from [Intel Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch) by Intel.
While <gh-pr:16859> updated vLLM to PyTorch 2.7.0 on CPU, CUDA, and ROCm,
<gh-pr:17444> completed the update for XPU.

@@ -1,7 +1,7 @@

# Dockerfile

We provide a <gh-file:docker/Dockerfile> to construct the image for running an OpenAI compatible server with vLLM.
More information about deploying with Docker can be found [here](../../deployment/docker.md).
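
As a quick orientation, building this image from the repository root looks roughly like the sketch below; the build target and tag are illustrative rather than the exact names used in CI:

```bash
# Illustrative build of the OpenAI-compatible server image.
DOCKER_BUILDKIT=1 docker build . \
    --file docker/Dockerfile \
    --target vllm-openai \
    --tag vllm/vllm-openai:dev
```
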
Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes:

@@ -14,7 +14,7 @@ Before setting up the incremental build:

    ```bash
    VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto
    ```

2. **CUDA Toolkit:** Verify that the NVIDIA CUDA Toolkit is correctly installed and `nvcc` is accessible in your `PATH`. CMake relies on `nvcc` to compile CUDA code. You can typically find `nvcc` in `$CUDA_HOME/bin/nvcc` or by running `which nvcc`. If you encounter issues, refer to the [official CUDA Toolkit installation guides](https://developer.nvidia.com/cuda-toolkit-archive) and vLLM's main [GPU installation documentation](../getting_started/installation/gpu.md#troubleshooting) for troubleshooting. The `CMAKE_CUDA_COMPILER` variable in your `CMakeUserPresets.json` should also point to your `nvcc` binary.

3. **Build Tools:** It is highly recommended to install `ccache` for fast rebuilds by caching compilation results (e.g., `sudo apt install ccache` or `conda install ccache`). Also, ensure the core build dependencies like `cmake` and `ninja` are installed. These are installable through `requirements/build.txt` or your system's package manager. A quick sanity check for these prerequisites is sketched after this list.
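
The commands below are one way to sanity-check these prerequisites before configuring CMake; the package-manager invocations are examples and depend on your system:

```bash
# Confirm nvcc is on PATH and report its version.
which nvcc && nvcc --version

# Install ccache (pick whichever matches your environment).
sudo apt install ccache    # or: conda install ccache

# Core build dependencies such as cmake and ninja.
uv pip install -r requirements/build.txt
```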

@@ -84,6 +84,7 @@ Below is an example of what the generated `CMakeUserPresets.json` might look lik
```

**What do the various configurations mean?**

- `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically.
- `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default.
- `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable.

@@ -98,16 +99,16 @@ Once your `CMakeUserPresets.json` is configured:

1. **Initialize the CMake build environment:**
    This step configures the build system according to your chosen preset (e.g., `release`) and creates the build directory at `binaryDir`.

    ```console
    cmake --preset release
    ```

2. **Build and install the vLLM components:**
    This command compiles the code and installs the resulting binaries into your vLLM source directory, making them available to your editable Python installation.

    ```console
    cmake --build --preset release --target install
    ```

3. **Make changes and repeat!**
    Now you start using your editable install of vLLM, testing and making changes as needed. If you need to build again to update based on changes, simply run the CMake command again to build only the affected files.

@@ -1,12 +1,9 @@

[](){ #new-model }

# Summary

!!! important
    Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve <model>` works first!

vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/compatibility_matrix.md) to optimize their performance.

The complexity of integrating a model into vLLM depends heavily on the model's architecture.
The process is considerably more straightforward if the model shares a similar architecture with an existing model in vLLM.

@@ -1,7 +1,4 @@

[](){ #new-model-basic }

# Basic Model

This guide walks you through the steps to implement a basic vLLM model.

@@ -27,7 +24,7 @@ All vLLM modules within the model must include a `prefix` argument in their cons

The initialization code should look like this:

??? code

    ```python
    from torch import nn

@@ -76,6 +73,8 @@ def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    ...
```

@@ -108,7 +107,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a

## 5. Register your model

See [this page](registration.md) for instructions on how to register your new model to be used by vLLM.

## Frequently Asked Questions

@@ -1,18 +1,15 @@

[](){ #supports-multimodal }

# Multi-Modal Support

This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](../../features/multimodal_inputs.md).

## 1. Update the base vLLM model

It is assumed that you have already implemented the model in vLLM according to [these steps](basic.md).
Further update the model as follows:

- Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model.

??? code

    ```python
    class YourModelForImage2Seq(nn.Module):

@@ -41,7 +38,7 @@ Further update the model as follows:

- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.

??? code

    ```python
    class YourModelForImage2Seq(nn.Module):

@@ -71,7 +68,7 @@ Further update the model as follows:

- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.

??? code

    ```python
    from .utils import merge_multimodal_embeddings

@@ -155,7 +152,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

Looking at the code of HF's `LlavaForConditionalGeneration`:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544

@@ -179,7 +176,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

The number of placeholder feature tokens per image is `image_features.shape[1]`.
`image_features` is calculated inside the `get_image_features` method:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300

@@ -217,7 +214,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257

@@ -244,7 +241,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

Overall, the number of placeholder feature tokens for an image can be calculated as:

??? code

    ```python
    def get_num_image_tokens(

@@ -269,7 +266,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

Notice that the number of image tokens doesn't depend on the image width and height.
We can simply use a dummy `image_size` to calculate the multimodal profiling data:

??? code

    ```python
    # NOTE: In actuality, this is usually implemented as part of the

@@ -314,7 +311,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

Looking at the code of HF's `FuyuForCausalLM`:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322

@@ -344,7 +341,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`,
returning the dimensions after resizing (but before padding) as metadata.

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544

@@ -382,7 +379,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425

@@ -420,7 +417,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562

@@ -457,7 +454,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in

For the multimodal image profiling data, the logic is very similar to LLaVA:

??? code

    ```python
    def get_dummy_mm_data(

@@ -483,7 +480,7 @@ Afterwards, create a subclass of [BaseMultiModalProcessor][vllm.multimodal.proce
to fill in the missing details about HF processing.

!!! info
    [Multi-Modal Data Processing](../../design/mm_processing.md)

### Multi-modal fields

@@ -546,7 +543,7 @@ return a schema of the tensors outputted by the HF processor that are related to

In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA,
we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]:

??? code

    ```python
    def _call_hf_processor(

@@ -623,7 +620,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
Based on this, we override [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] as follows:

??? code

    ```python
    def _get_prompt_updates(

@@ -668,7 +665,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

We define a helper function to return `ncols` and `nrows` directly:

??? code

    ```python
    def get_image_feature_grid_size(

@@ -698,7 +695,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

Based on this, we can initially define our replacement tokens as:

??? code

    ```python
    def get_replacement(item_idx: int):

@@ -718,7 +715,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called,
a BOS token (`<s>`) is also added to the prompt:

??? code

    ```python
    # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435

@@ -745,7 +742,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

To assign the vision embeddings to only the image tokens, instead of a string
you can return an instance of [PromptUpdateDetails][vllm.multimodal.processing.PromptUpdateDetails]:

??? code

    ```python
    hf_config = self.info.get_hf_config()

@@ -772,7 +769,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
we can search for it to conduct the replacement at the start of the string:

??? code

    ```python
    def _get_prompt_updates(

@@ -819,7 +816,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies

After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2),
[BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3),
and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4),
decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.processing.MultiModalRegistry.register_processor]
to register them to the multi-modal registry:

```diff

@@ -846,7 +843,7 @@ Examples:

### Handling prompt updates unrelated to multi-modal data

[_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override [_apply_hf_processor_tokens_only][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only] so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design](../../design/mm_processing.md).

Examples:

@@ -1,10 +1,7 @@

[](){ #new-model-registration }

# Registering a Model

vLLM relies on a model registry to determine how to run each model.
A list of pre-registered architectures can be found [here](../../models/supported_models.md).

If your model is not on this list, you must register it to vLLM.
This page provides detailed instructions on how to do so.

@@ -14,16 +11,16 @@ This page provides detailed instructions on how to do so.

To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source][build-from-source].
This gives you the ability to modify the codebase and test your model.

After you have implemented your model (see [tutorial](basic.md)), put it into the <gh-dir:vllm/model_executor/models> directory.
Then, add your model class to `_VLLM_MODELS` in <gh-file:vllm/model_executor/models/registry.py> so that it is automatically registered upon importing vLLM.
Finally, update our [list of supported models](../../models/supported_models.md) to promote your model!

!!! important
    The list of models in each section should be maintained in alphabetical order.

## Out-of-tree models

You can load an external model [using a plugin](../../design/plugin_system.md) without modifying the vLLM codebase.

To register the model, use the following code:

@@ -51,4 +48,4 @@ def register():

!!! important
    If your model is a multimodal model, ensure the model class implements the [SupportsMultiModal][vllm.model_executor.models.interfaces.SupportsMultiModal] interface.
    Read more about that [here](multimodal.md).

@@ -1,7 +1,4 @@

[](){ #new-model-tests }

# Unit Testing

This page explains how to write unit tests to verify the implementation of your model.

@@ -125,7 +125,7 @@ to manually kill the profiler and generate your `nsys-rep` report.

You can view these profiles either as summaries in the CLI, using `nsys stats [profile-file]`, or in the GUI by installing Nsight [locally following the directions here](https://developer.nvidia.com/nsight-systems/get-started).

??? console "CLI example"

    ```bash
    nsys stats report1.nsys-rep

@@ -1,7 +1,4 @@

[](){ #deployment-docker }

# Using Docker

[](){ #deployment-docker-pre-built-image }

@@ -32,7 +29,7 @@ podman run --gpus all \
    --model mistralai/Mistral-7B-v0.1
```

You can add any other [engine-args](../configuration/engine_args.md) you need after the image tag (`vllm/vllm-openai:latest`), as in the sketch below.
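
The sketch below shows the idea; the base `docker run` flags mirror the standard invocation from earlier in this page, and the trailing engine arguments are the illustrative part:

```bash
# Illustrative: everything after the image tag is passed to vLLM as engine arguments.
docker run --runtime nvidia --gpus all \
    -p 8000:8000 \
    --ipc=host \
    vllm/vllm-openai:latest \
    --model mistralai/Mistral-7B-v0.1 \
    --max-model-len 8192
```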

!!! note
    You can either use the `ipc=host` flag or `--shm-size` flag to allow the

@@ -97,7 +94,7 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `--
flags to speed up the build process. However, ensure your `max_jobs` is substantially larger than `nvcc_threads` to get the most benefits.
Keep an eye on memory usage with parallel jobs as it can be substantial (see example below).

??? console "Command"

    ```bash
    # Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB)

docs/deployment/frameworks/anyscale.md (new file, 8 lines)

@@ -0,0 +1,8 @@

# Anyscale

[](){ #deployment-anyscale }

[Anyscale](https://www.anyscale.com) is a managed, multi-cloud platform developed by the creators of Ray.
It hosts Ray clusters inside your own AWS, GCP, or Azure account, delivering the flexibility of open-source Ray
without the operational overhead of maintaining Kubernetes control planes, configuring autoscalers, or managing observability stacks.
When serving large language models with vLLM, Anyscale can rapidly provision [production-ready HTTPS endpoints](https://docs.anyscale.com/examples/deploy-ray-serve-llms) or [fault-tolerant batch inference jobs](https://docs.anyscale.com/examples/ray-data-llm).

@@ -1,7 +1,4 @@

[](){ #deployment-anything-llm }

# Anything LLM

[Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting.

@@ -1,7 +1,4 @@

[](){ #deployment-autogen }

# AutoGen

[AutoGen](https://github.com/microsoft/autogen) is a framework for creating multi-agent AI applications that can act autonomously or work alongside humans.

@@ -30,7 +27,7 @@ python -m vllm.entrypoints.openai.api_server \

- Call it with AutoGen:

??? code

    ```python
    import asyncio

@@ -1,7 +1,4 @@

[](){ #deployment-bentoml }

# BentoML

[BentoML](https://github.com/bentoml/BentoML) allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. You can serve the model locally or containerize it as an OCI-compliant image and deploy it on Kubernetes.

@@ -1,7 +1,4 @@

[](){ #deployment-cerebrium }

# Cerebrium

<p align="center">
    <img src="https://i.ibb.co/hHcScTT/Screenshot-2024-06-13-at-10-14-54.png" alt="vLLM_plus_cerebrium"/>

@@ -34,7 +31,7 @@ vllm = "latest"

Next, let us add our code to handle inference for the LLM of your choice (`mistralai/Mistral-7B-Instruct-v0.1` for this example). Add the following code to your `main.py`:

??? code

    ```python
    from vllm import LLM, SamplingParams

@@ -64,7 +61,7 @@ cerebrium deploy

If successful, you should be returned a CURL command that you can use to call inference. Just remember to end the URL with the function name you are calling (in our case `/run`).

??? console "Command"

    ```bash
    curl -X POST https://api.cortex.cerebrium.ai/v4/p-xxxxxx/vllm/run \

@@ -82,7 +79,7 @@ If successful, you should be returned a CURL command that you can call inference

You should get a response like:

??? console "Response"

    ```json
    {

@@ -1,7 +1,4 @@

[](){ #deployment-chatbox }

# Chatbox

[Chatbox](https://github.com/chatboxai/chatbox) is a desktop client for LLMs, available on Windows, Mac, and Linux.

@@ -1,7 +1,4 @@

[](){ #deployment-dify }

# Dify

[Dify](https://github.com/langgenius/dify) is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.

@@ -1,7 +1,4 @@

[](){ #deployment-dstack }

# dstack

<p align="center">
    <img src="https://i.ibb.co/71kx6hW/vllm-dstack.png" alt="vLLM_plus_dstack"/>

@@ -26,7 +23,7 @@ dstack init

Next, to provision a VM instance with the LLM of your choice (`NousResearch/Llama-2-7b-chat-hf` for this example), create the following `serve.dstack.yml` file for the dstack `Service`:

??? code "Config"

    ```yaml
    type: service

@@ -48,7 +45,7 @@ Next, to provision a VM instance with LLM of your choice (`NousResearch/Llama-2-

Then, run the following CLI command for provisioning:

??? console "Command"

    ```console
    $ dstack run . -f serve.dstack.yml

@@ -79,7 +76,7 @@ Then, run the following CLI for provisioning:

After the provisioning, you can interact with the model by using the OpenAI SDK:

??? code

    ```python
    from openai import OpenAI

@@ -1,7 +1,4 @@

[](){ #deployment-haystack }

# Haystack

@@ -27,7 +24,7 @@ vllm serve mistralai/Mistral-7B-Instruct-v0.1

- Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server.

??? code

    ```python
    from haystack.components.generators.chat import OpenAIChatGenerator

@@ -1,7 +1,4 @@

[](){ #deployment-helm }

# Helm

A Helm chart to deploy vLLM for Kubernetes.

@@ -1,7 +1,4 @@

[](){ #deployment-litellm }

# LiteLLM

With [LiteLLM](https://github.com/BerriAI/litellm), you can call all LLM APIs using the OpenAI format (Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, Groq, etc.).

@@ -34,7 +31,7 @@ vllm serve qwen/Qwen1.5-0.5B-Chat

- Call it with litellm:

??? code

    ```python
    import litellm

@@ -1,7 +1,4 @@

[](){ #deployment-lobe-chat }

# Lobe Chat

[Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLMs UI/Framework.

@@ -1,7 +1,4 @@

[](){ #deployment-lws }

# LWS

LeaderWorkerSet (LWS) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads.
A major use case is for multi-host/multi-node distributed inference.

@@ -17,7 +14,7 @@ vLLM can be deployed with [LWS](https://github.com/kubernetes-sigs/lws) on Kuber

Deploy the following yaml file `lws.yaml`:

??? code "Yaml"

    ```yaml
    apiVersion: leaderworkerset.x-k8s.io/v1

@@ -177,7 +174,7 @@ curl http://localhost:8080/v1/completions \

The output should be similar to the following:

??? console "Output"

    ```text
    {

@@ -1,7 +1,4 @@

[](){ #deployment-modal }

# Modal

vLLM can be run on cloud GPUs with [Modal](https://modal.com), a serverless computing platform designed for fast auto-scaling.

@@ -1,7 +1,4 @@

[](){ #deployment-open-webui }

# Open WebUI

1. Install [Docker](https://docs.docker.com/engine/install/).